Understanding Shared Variables in TensorFlow

This article uses a few examples to introduce the two ways of declaring variables in TensorFlow and the two kinds of scope.

References

Overview: https://www.tensorflow.org/programmers_guide/variables

tf.Variable(): https://www.tensorflow.org/api_docs/python/tf/Variable

tf.get_variable(): https://www.tensorflow.org/api_docs/python/tf/get_variable

tf.variable_scope(): https://www.tensorflow.org/api_docs/python/tf/variable_scope

tf.name_scope(): https://www.tensorflow.org/api_docs/python/tf/name_scope

Introduction

This article uses a few examples to introduce the two ways of declaring variables in TensorFlow (the 1.x graph API):

  • tf.Variable()
  • tf.get_variable()

and the two kinds of scope:

  • tf.name_scope()
  • tf.variable_scope()

tf.name_scope() and tf.variable_scope()

  • Example 1

    import tensorflow as tf

    with tf.name_scope("Ph0en1x_name_scope"):
        initializer = tf.constant_initializer(value=1)
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer)
        var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
        var3 = var2 * 2

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favor of this call
        sess.run(tf.global_variables_initializer())
        print(var1.name)
        print(var2.name)
        print(var3.name)
  • Output 1

    var1:0
    Ph0en1x_name_scope/var2:0
    Ph0en1x_name_scope/mul:0

    As the output shows, variables created with tf.Variable() and ops are prefixed by the name_scope, while variables created with tf.get_variable() are not.

  • Example 2

    import tensorflow as tf

    with tf.variable_scope("Ph0en1x_variable_scope"):
        initializer = tf.constant_initializer(value=1)
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer)
        var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
        var3 = var2 * 2

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favor of this call
        sess.run(tf.global_variables_initializer())
        print(var1.name)
        print(var2.name)
        print(var3.name)
  • Output 2

    Ph0en1x_variable_scope/var1:0
    Ph0en1x_variable_scope/var2:0
    Ph0en1x_variable_scope/mul:0

    If variable_scope is used instead, all three names are prefixed by the scope.
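
The naming rule seen in Examples 1 and 2 can be summarized with a small pure-Python model. This is only a sketch for illustration, not TensorFlow code: the function `resolved_name` and its arguments are hypothetical, and it encodes just the behavior observed above. It leaves out the `:0` output suffix and the automatic renaming of duplicate names.

```python
def resolved_name(creator, name, scopes):
    """Return the prefixed name a variable or op gets in this toy model.

    creator: "get_variable", "Variable", or "op" (hypothetical labels).
    scopes:  list of (kind, scope_name) pairs, outermost first, where kind
             is "name_scope" or "variable_scope".
    """
    if creator == "get_variable":
        # Only variable_scope prefixes apply; name_scope is ignored.
        prefixes = [s for kind, s in scopes if kind == "variable_scope"]
    else:
        # tf.Variable() and ops take the prefix of every enclosing scope,
        # because a variable_scope also opens a name scope.
        prefixes = [s for _, s in scopes]
    return "/".join(prefixes + [name])


in_name_scope = [("name_scope", "Ph0en1x_name_scope")]
in_variable_scope = [("variable_scope", "Ph0en1x_variable_scope")]

print(resolved_name("get_variable", "var1", in_name_scope))
print(resolved_name("Variable", "var2", in_name_scope))
print(resolved_name("get_variable", "var1", in_variable_scope))
print(resolved_name("op", "mul", in_variable_scope))
```

The four printed names match the prefixes in Output 1 and Output 2, minus the `:0` suffix.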

tf.Variable() and tf.get_variable()

  • Example 3

    import tensorflow as tf

    with tf.variable_scope("Ph0en1x_variable_scope"):
        initializer = tf.constant_initializer(value=1)
        # Both variables request the same name, 'var1'
        var1 = tf.Variable(name='var1', initial_value=[2], dtype=tf.float32)
        var1_1 = tf.Variable(name='var1', initial_value=[2], dtype=tf.float32)

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favor of this call
        sess.run(tf.global_variables_initializer())
        print(var1.name)
        print(var1_1.name)
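  • Output 3

    Ph0en1x_variable_scope/var1:0
    Ph0en1x_variable_scope/var1_1:0

    When two variables with the same name are declared with tf.Variable() in the same scope, TensorFlow automatically renames the second one. tf.get_variable(), however, behaves differently: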

  • Example 4

    import tensorflow as tf

    with tf.variable_scope("Ph0en1x_variable_scope"):
        initializer = tf.constant_initializer(value=1)

        var2 = tf.get_variable(name='var2', shape=[1], dtype=tf.float32, initializer=initializer)
        # The second request for the same name raises ValueError
        var2_2 = tf.get_variable(name='var2', shape=[1], dtype=tf.float32, initializer=initializer)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var2.name)
        print(var2_2.name)
  • Result 4

    ValueError: Variable Ph0en1x_variable_scope/var2 already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

    tf.get_variable() does not rename automatically. It raises an exception instead, asking whether you meant to set reuse to True.
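
The contrast between Examples 3 and 4 can be mimicked in plain Python. The sketch below is a toy model under stated assumptions, not how TensorFlow is implemented: `ToyGraph`, `make_variable`, and `get_variable` are hypothetical names, and the two code paths are kept independent, so their interaction inside TensorFlow is not modeled.

```python
class ToyGraph:
    """Toy stand-in for a graph that tracks variable names."""

    def __init__(self):
        self.used_names = set()  # names handed out by make_variable()
        self.store = {}          # name -> value, used by get_variable()

    def make_variable(self, name, value):
        """Mimics tf.Variable(): always creates, renaming on a clash."""
        unique, suffix = name, 0
        while unique in self.used_names:
            suffix += 1
            unique = "%s_%d" % (name, suffix)
        self.used_names.add(unique)
        return unique, value

    def get_variable(self, name, value):
        """Mimics tf.get_variable() without reuse: a clash is an error."""
        if name in self.store:
            raise ValueError("Variable %s already exists, disallowed." % name)
        self.store[name] = value
        return name, value


graph = ToyGraph()
first, _ = graph.make_variable("scope/var1", 2.0)
second, _ = graph.make_variable("scope/var1", 2.0)
print(first, second)  # scope/var1 scope/var1_1

graph.get_variable("scope/var2", 1.0)
try:
    graph.get_variable("scope/var2", 1.0)
except ValueError as err:
    print(err)
```

The first path silently produces a second, independent variable, while the second path refuses to guess whether a new variable or the existing one was intended.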

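The reuse parameter

Combining tf.get_variable(), tf.variable_scope(), and the reuse parameter makes it possible to share variables.

Why use shared variables?

For example, when training the discriminator of a generative adversarial network (GAN), the discriminator receives two sets of inputs, real images and generated images: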

  • Discriminator definition

    import tensorflow as tf

    def get_discriminator(img, n_units, reuse=False, alpha=0.01):
        # reuse=True makes this call share the variables of an earlier call
        with tf.variable_scope('discriminator', reuse=reuse):
            hidden1 = tf.layers.dense(img, n_units)
            # Leaky ReLU with slope alpha for negative inputs
            hidden1 = tf.maximum(alpha * hidden1, hidden1)

            logits = tf.layers.dense(hidden1, 1)
            outputs = tf.sigmoid(logits)
        return logits, outputs
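  • Discriminator instances

    d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
    d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

    Both sets of images must go through the same discriminator variables, yet get_discriminator() is called twice. If reuse is not set to True for the second call, the two calls cannot share variables, and with variables created directly through tf.get_variable() the second call raises the exception shown above.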

  • Example 5

    import tensorflow as tf

    # Variant 1: re-enter the scope with reuse=True
    with tf.variable_scope("foo"):
        v = tf.get_variable("var", [1])
    with tf.variable_scope("foo", reuse=True):
        v1 = tf.get_variable("var", [1])
    print(v1 is v)

    import tensorflow as tf

    # Variant 2 (run in a fresh graph): switch the current scope to reuse mode
    with tf.variable_scope("foo"):
        v = tf.get_variable("var", [1])
        tf.get_variable_scope().reuse_variables()
        v1 = tf.get_variable("var", [1])
    print(v1 is v)
  • Result 5

    True

    However, once reuse has been set to True, it cannot be set back to False within that scope.