2016-11-25 65 views
1

描述here加载一些训练图像分批,即,基本上是这样的:培训VS测试与我使用的是设置队列

def read_my_file_format(filename_queue): 
    # ... use a reader + a decoder 

def input_pipeline(filenames, batch_size, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(...) 
    example, label = read_my_file_format(filename_queue) 
    example_batch, label_batch = tf.train.shuffle_batch(
     [example, label], batch_size=batch_size, ...) 
    return example_batch, label_batch 

def build_net(): 
    batch, label = input_pipeline(...) 
    y = encoder(batch) # <- build network using the batch 

def train(): 
    with tf.Session() as sess: 
    # ... init vars 

    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    try: 
     while not coord.should_stop(): 
     # ... training step 

    except tf.errors.OutOfRangeError: 
     print('Done training -- epoch limit reached') 
    finally: 
     coord.request_stop() 

    coord.join(threads) 
    sess.close() 

这就是很好的训练 - 但是,我怎么没看到我可以测试最终的网络!什么使我困惑:

  • input_pipeline返回的张量是网络的一部分。为了测试,我将不得不更换它?
  • 我想我可以创建另一个input_pipeline进行测试,即使用不同的文件名队列。然后我可以使用tf.cond在不同的输入批次之间切换,但是:如何确保一次只有一个队列耗尽。我看不到如何访问不同的队列以及如何指定它们如何卸载。

基本上,这个问题归结为:什么是测试网络的规范方式使用tf.train.shuffle_batch方法构建。

回答

1

你是绝对正确的轨道创造了附加的输入管道的想法上数据集评估。使用multiple input pipelines是推荐的方法之一,其将由两个过程组成 - 一方面是训练,另一方面是评估。检查点将在训练过程中使用,然后每千步骤,代码可以尝试针对训练数据集和测试数据集两者的模型eval

从文档报价:

  • 训练过程训练读取输入数据,并定期与所有训练的变量写检查点文件。
  • 评估过程将检查点文件恢复为读取验证输入数据的推理模型。

即使在培训完成/退出后也可以进行评估。 (see this example

另一个考虑是通过sharing variables train和eval可以在同一个过程中在同一个图中操作,同时分享他们训练过的变量!

关于您拥有的队列耗尽问题,如果您使用tf.train.shuffle_batch*将num_threads设置为大于1,它将同时从单个文件读取(+比使用1个线程更快),而不是同时读取N个文件(请参阅关于batching的部分)。

+0

听起来不错,我现在仔细看看这个 – fabian789

1

我的想法是使用一个字符串占位符,即,假设你有多个输入文件:

filenames_place = tf.placeholder(tf.string, shape=[None]) 
num_epochs_place = tf.placeholder(tf.int32) 
example_batch, label_batch = input_pipeline(filenames_place, batch_size, num_epochs_place) 
... 
try: 
    sess.run(train_op, feed_dict={filenames_place: ["train_data1", "train_data2"], num_epochs_place=5}) 

except tf.errors.OutOfRangeError: 
    print('Done training -- epoch limit reached') 

sess.run(eval_op, feed_dict={filenames_place: ["test_data"], num_epochs_place=1}) 
+0

这实际上工作吗?我觉得'string_input_producer'创建后无法更改文件名,但不确定 – fabian789