2016-09-06 1806 views
1

我想在火车组(is_training=True)和验证集(is_training=False)上运行给定模型,具体说明如何应用dropout。现在,prebuilt models公开了一个参数is_training,在构建网络时,该参数传递给dropout层。问题是,如果我使用不同的值is_training两次调用方法,我会得到两个不分享权重的网络(我认为?)。我怎样才能让两个网络共享相同的权重,以便我可以运行我在验证集上训练过的网络?带有is_training True和False的Tensorflow(tf-slim)模型

+0

我觉得默认的行为是共享两种情况之间的权重,所以你不要有什么关系。 'tf-slim'使用'tf.get_variable()'来重用调用之间的变量。 –

+0

好的,我认为这主要是有效的。你需要确保'scope'被设置,然后为了安全,最好也设置'reuse = True'。 –

回答

1

我写了一个解决方案,您的评论在列车和测试模式中使用Overfeat。 (我无法测试,所以你可以检查是否正常工作?)

一是部分进口及参数:

import tensorflow as tf 
slim = tf.contrib.slim 
overfeat = tf.contrib.slim.nets.overfeat 

batch_size = 32 
inputs = tf.placeholder(tf.float32, [batch_size, 231, 231, 3]) 
dropout_keep_prob = 0.5 
num_classes = 1000 

在训练模式,我们通过正常范围到功能overfeat

scope = 'overfeat' 
is_training = True 

output = overfeat.overfeat(inputs, num_classes, is_training,   
          dropout_keep_prob, scope=scope) 

然后在测试模式下,我们创建了与reuse=True相同的范围。

scope = tf.VariableScope(reuse=True, name='overfeat') 
is_training = False 

output = overfeat.overfeat(inputs, num_classes, is_training,   
          dropout_keep_prob, scope=scope) 
0

你可以只使用一个占位符is_training:

isTraining = tf.placeholder(tf.bool) 

# create nn 
net = ... 
net = slim.dropout(net, 
        keep_prob=0.5, 
        is_training=isTraining) 
net = ... 

# training 
sess.run([net], feed_dict={isTraining: True}) 

# testing 
sess.run([net], feed_dict={isTraining: False}) 
+1

我试过这个,并且遇到了问题,因为变量没有被重用。我也遇到了我无法解释的内存限制。 –

0

这要看情况下,解决方案是不同的。

我的第一个选择是使用不同的流程来进行评估。你只需要检查是否有新的关卡和加载权纳入评价网络(与is_training=False):

checkpoint = tf.train.latest_checkpoint(self.checkpoints_path) 
# wait until a new check point is available 
while self.lastest_checkpoint == checkpoint: 
    time.sleep(30) # sleep 30 seconds waiting for a new checkpoint 
    checkpoint = tf.train.latest_checkpoint(self.checkpoints_path) 
logging.info('Restoring model from {}'.format(checkpoint)) 
self.saver.restore(session, checkpoint) 
self.lastest_checkpoint = checkpoint 

第二个选项是每一个时代后您卸载图形,并创建一个新的评价用图。这个解决方案浪费了很多时间加载和卸载图形。

第三个选项是分享权重。但是给这些网络添加队列或数据集可能会导致问题,所以您必须非常小心。我只用于连体网络。

with tf.variable_scope('the_scope') as scope: 
    your_model(is_training=True) 
    scope.reuse_variables() 
    your_model(is_training=False)