前言:
在Tensorflow中我们想共享变量的时候,需要在一个name_scope下,网络中的一些参数有时利用也是如此,尤其是在RNN中,接下来向大家举例!
代码如下:
#这里的每个变量都不一样,没有共享
with tf.name_scope("a_name_scope"):
initializer = tf.constant_initializer(value=1)
var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer)
var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
var21 = tf.Variable(name='var2', initial_value=[2.1], dtype=tf.float32)
#这里的var3 重复利用,共享变量
with tf.variable_scope("a_variable_scope") as scope:
initializer = tf.constant_initializer(value=3)
var3 = tf.get_variable(name='var3', shape=[1], dtype=tf.float32, initializer=initializer)
scope.reuse_variables()
var3_reuse = tf.get_variable(name='var3',)
#这里RNN,重复利用训练的参数例子
with tf.variable_scope('rnn') as scope:
sess = tf.Session()
train = RNN(train_config)
scope.reuse_variables()
test = RNN(test_config)
sess.run(tf.global_variables_initializer())
欢迎大家学习交流! ;)