我对tf.train.string_input_producer
的工作原理有些疑问。因此,假设我将filename_list作为输入参数提供给string_input_producer
。然后,根据文档https://www.tensorflow.org/programmers_guide/reading_data,这将创建一个FIFOQueue
,我可以在其中设置时代号,随机播放文件名等。因此,就我而言,我有4个文件名(“db1.tfrecords”,“db2.tfrecords”...)。我使用tf.train.batch
来提供网络批次的图像。另外,每个文件名/数据库都包含一组人员的图像。第二个数据库是针对第二个人的,等等。到目前为止,我有以下代码:用张量流确定tf.train.string_input_producer的时代数
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
(common + "P21_db.tfrecords")]
filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'annotation_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image, [height, width, 3])
annotation = tf.cast(features['annotation_raw'], tf.string)
min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
shapes=[[], [112, 112, 3]],
batch_size=batch_size,
capacity=capacity,
num_threads=num_threads)
最后,尝试在自动编码器的输出,查看了重建图像的时候,我得到了第一个从第一数据库中的图像,然后我开始从观看图像第二个数据库等等。
我的问题:我怎么知道我是否在同一个时代?如果我处于理智的时代,我如何合并一批来自我拥有的所有file_names的图像?
最后,我试图通过如下Session
内评估局部变量打印出时代的价值:
epoch_var = tf.local_variables()[0]
然后:
with tf.Session() as sess:
print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y.
任何帮助深表感谢!
你可以使用'tf.python_io.tf_record_iterator'来计算记录的数量,并给出批量的大小,你应该得到当前的epoch编号。虽然没有得到你的第二个问题。 –
@vijaym,这不是我所问的。我有'tf.train.string_input_producer',而不是'tf.python_io.tf_record_iterator'。 –