Determining the epoch number of tf.train.string_input_producer in TensorFlow

I have a question about how tf.train.string_input_producer works. Suppose I pass filename_list as the input argument to string_input_producer. According to the documentation at https://www.tensorflow.org/programmers_guide/reading_data, this creates a FIFOQueue in which I can set the number of epochs, shuffle the file names, and so on. In my case I have 4 file names ("db1.tfrecords", "db2.tfrecords", ...), and I use tf.train.batch to feed batches of images to the network. In addition, each file name/database contains images of a single person: the second database is for the second person, and so on. So far I have the following code:

import tensorflow as tf

# 'common' (the path prefix), 'num_epoch' and 'batch_size' are defined elsewhere in my script.
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
                          (common + "P21_db.tfrecords")]

filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue') 
reader = tf.TFRecordReader() 

key, serialized_example = reader.read(filename_queue) 
features = tf.parse_single_example(
    serialized_example, 
    # Defaults are not specified since both keys are required. 
    features={ 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'annotation_raw': tf.FixedLenFeature([], tf.string) 
    }) 

image = tf.decode_raw(features['image_raw'], tf.uint8) 
height = tf.cast(features['height'], tf.int32) 
width = tf.cast(features['width'], tf.int32) 

image = tf.reshape(image, [height, width, 3]) 

annotation = tf.cast(features['annotation_raw'], tf.string) 

min_after_dequeue = 100 
num_threads = 4 
capacity = min_after_dequeue + num_threads * batch_size 
label_batch, images_batch = tf.train.batch([annotation, image], 
                 shapes=[[], [112, 112, 3]], 
                 batch_size=batch_size, 
                 capacity=capacity, 
                 num_threads=num_threads) 

Finally, when I looked at the reconstructed images at the output of the autoencoder, I first got images from the first database, then I started seeing images from the second database, and so on.

My question: how can I tell whether I am still in the same epoch? And, within a given epoch, how can I build a batch that merges images from all of the file_names I have?

Finally, I tried to print out the value of the epoch by evaluating a local variable inside the Session, as follows:

epoch_var = tf.local_variables()[0] 

and then:

with tf.Session() as sess: 
    print(sess.run(epoch_var)) # Here I got 9 as output; I don't know why.
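
(For reference, a minimal sketch of how the epoch counter that string_input_producer creates internally via tf.train.limit_epochs could be located and read once the queue runners are running; picking the local variable whose name contains 'epochs' is my own assumption, not something from the original code.)

with tf.Session() as sess:
    # The epoch counter lives in the local variables, so it must be initialized.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Assumption: the producer's epoch counter is the local variable whose name contains 'epochs'.
    epoch_var = [v for v in tf.local_variables() if 'epochs' in v.name][0]
    # This counts how many times the producer has already pushed the full filename list
    # into the queue, which can run ahead of what the training loop has actually consumed.
    print(sess.run(epoch_var))

    coord.request_stop()
    coord.join(threads)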

Any help is much appreciated!

You could use 'tf.python_io.tf_record_iterator' to count the number of records, and given the batch size you should get the current epoch number. I did not get your second question, though. –
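
(For context, counting the records as the commenter suggests might look like the rough sketch below; the file list name is taken from the question.)

# Count how many examples are stored across all the TFRecord files, so that
# (global_step * batch_size) / total_records gives an estimate of the current epoch.
total_records = sum(
    sum(1 for _ in tf.python_io.tf_record_iterator(fname))
    for fname in tfrecords_filename_seq)
print(total_records)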

@vijaym, that is not what I asked. I am using 'tf.train.string_input_producer', not 'tf.python_io.tf_record_iterator'. –

Answer

So what I figured out is that using tf.train.shuffle_batch_join solves my problem, because it starts shuffling images from the different datasets. In other words, every batch now contains images from all of the datasets/file names. Here is an example:

def read_my_file_format(filename_queue): 
    reader = tf.TFRecordReader() 
    key, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(
     serialized_example, 
     # Defaults are not specified since both keys are required. 
     features={ 
      'height': tf.FixedLenFeature([], tf.int64), 
      'width': tf.FixedLenFeature([], tf.int64), 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'annotation_raw': tf.FixedLenFeature([], tf.string) 
     }) 

    # This is how we create one example, that is, extract one example from the database. 
    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    # The height and the width are used to reshape the flattened image back to its original shape.
    height = tf.cast(features['height'], tf.int32) 
    width = tf.cast(features['width'], tf.int32) 

    # The image is reshaped because, when stored in binary format, it is flattened. Therefore, we need the
    # height and the width to restore the original image.
    image = tf.reshape(image, [height, width, 3]) 

    annotation = tf.cast(features['annotation_raw'], tf.string) 
    return annotation, image 

def input_pipeline(filenames, batch_size, num_threads, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=False,
                                                    name='queue')
    # Note that here we create num_threads readers, all reading from the same filename_queue.
    example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)] 
    min_after_dequeue = 100 
    capacity = min_after_dequeue + num_threads * batch_size 
    label_batch, images_batch = tf.train.shuffle_batch_join(example_list, 
                  shapes=[[], [112, 112, 3]], 
                  batch_size=batch_size, 
                  capacity=capacity, 
                  min_after_dequeue=min_after_dequeue) 
    return label_batch, images_batch, example_list 

label_batch, images_batch, input_ann_img = \ 
    input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch) 

Now this creates a set of readers reading from the FIFOQueue, and each reader has its own decoder. Finally, after the images are decoded, they are fed into another queue, created by the call to tf.train.shuffle_batch_join, which supplies the network with batches of images.
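
Not part of the original answer, but for completeness, this is roughly how the pipeline above could be driven in a session, assuming batch_size, num_threads and num_epoch are defined as in the question:

with tf.Session() as sess:
    # Local variables hold the epoch counter of string_input_producer, so both
    # initializers are needed before starting the queue runners.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            labels, images = sess.run([label_batch, images_batch])
            # feed 'images'/'labels' to the network here
    except tf.errors.OutOfRangeError:
        # Raised once the requested num_epoch epochs have been consumed.
        print('Done: reached the end of the last epoch.')
    finally:
        coord.request_stop()
        coord.join(threads)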