2017-05-24 86 views
2

我面临的问题无法用我在互联网上找到的解决。Tensorflow:在输入队列中训练和测试同一个图形

我建立了我的神经网络,并将它连接到输入管道。从tfrecord 读取数据,与tf.train.batch和queueRunners,COORDS等。

我建立我的NN到一个名为 “模型”,我喜欢使用Python类:

模型=模型(...所有超参数在这里...)

...

model.predict()

model.step()

所有训练阶段工作得很好。

但是现在我想在每个X时代/训练阶段添加一个测试阶段。

我真的不知道该怎么做。 我有几个想法,但我没有找到最好的一个:

  • 复制代码到我的类来获得:loss_train和loss_test,等我图中的每个节点? (使用共享变量火车和测试之间)
  • 创建我模型的第2实例:

model_train =模型(重用=假)

model_test =模型(重用=真)

  • use tf.make_template?我真的没有发现这个功能的任何好例子...
  • 任何其他解决方案?

我将不胜感激任何建议,

回答

1

与TFRecords数据集进行实验,当我遇到同样的问题来了。有几种可能性。因为我想只有一个GPU做到这一点的计算机上反正我实现它,如下所示:

# Training Dataset 
train_dataset = tf.contrib.data.TFRecordDataset(train_files) 
train_dataset = train_dataset.map(parse_function) 
train_dataset = train_dataset.shuffle(buffer_size=10000) 
train_dataset = train_dataset.batch(200) 
# Validation Dataset 
validation_dataset = tf.contrib.data.TFRecordDataset(val_files) 
validation_dataset = validation_dataset.map(parse_function) 
validation_dataset = validation_dataset.batch(200) 

# A feedable iterator is defined by a handle placeholder and its structure. We 
# could use the `output_types` and `output_shapes` properties of either 
# `training_dataset` or `validation_dataset` here, because they have 
# identical structure. 
handle = tf.placeholder(tf.string, shape=[]) 
iterator = tf.contrib.data.Iterator.from_string_handle(handle, 
train_dataset.output_types, train_dataset.output_shapes) 
next_element = iterator.get_next() 

# Generate the Iterators 
training_iterator = train_dataset.make_initializable_iterator() 
validation_iterator = validation_dataset.make_one_shot_iterator() 

# The `Iterator.string_handle()` method returns a tensor that can be evaluated 
# and used to feed the `handle` placeholder. 
training_handle = sess.run(training_iterator.string_handle()) 
validation_handle = sess.run(validation_iterator.string_handle()) 

然后访问的元素,你可以去这样的:

img, lbl = sess.run(next_element, feed_dict={handle: training_handle}) 

和交流手柄取决于你愿意做什么ATM。

请记住,这不是可并行化的,但是。在此链接之后,您可以深入了解创建多个输入管道的不同方法Tensorflow | Reading Data