How to use TensorFlow tf.train.string_input_producer to produce data for multiple epochs?
When I want to load the data for 2 epochs with tf.train.string_input_producer, I use:
filename_queue = tf.train.string_input_producer(filenames=['data.csv'], num_epochs=2, shuffle=True)
col1_batch, col2_batch, col3_batch = tf.train.shuffle_batch([col1, col2, col3], batch_size=batch_size, capacity=capacity,
    min_after_dequeue=min_after_dequeue, allow_smaller_final_batch=True)
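But then I found that this op does not produce what I want.
It only guarantees that each sample in data.csv is produced 2 times; the order in which the samples come out is not controlled. For example, with 3 lines of data in data.csv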
[[1]
[2]
[3]]
the pipeline produces (each sample appears exactly 2 times, but in arbitrary order):
[1]
[1]
[3]
[2]
[2]
[3]
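Why the two epochs get mixed can be modeled in plain Python. tf.train.string_input_producer emits file *names*, and its shuffle=True only shuffles the filename list within each epoch, so with a single file it does not reorder lines at all. The line order is mixed by tf.train.shuffle_batch, which keeps a buffer of at least min_after_dequeue elements and dequeues random ones; since the reader pushes epoch 1 and then epoch 2 into the same queue, records from both epochs can sit in that buffer together. The function below is a simplified, hypothetical model of such a random-draw buffer, not TensorFlow's actual implementation. With min_after_dequeue = 10 (as in the code further down) the whole 6-record stream fits in the buffer, so the result is one shuffle over both epochs.

```python
import random

def shuffle_buffer(records, min_after_dequeue, seed=None):
    """Simplified model of a shuffle queue (an assumption, not TF's code).

    Records are appended to a buffer; once it holds more than
    min_after_dequeue items, one item is drawn uniformly at random per
    new arrival. When the input is exhausted the rest is drained randomly.
    """
    rng = random.Random(seed)
    buffer = []
    for rec in records:
        buffer.append(rec)
        if len(buffer) > min_after_dequeue:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

# num_epochs=2 feeds epoch 1 and then epoch 2 into one continuous stream.
stream = [1, 2, 3] + [1, 2, 3]
out = list(shuffle_buffer(stream, min_after_dequeue=10, seed=0))
print(out)  # every value appears exactly twice, but the epochs are interleaved
```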
But what I want is for the epochs to stay separate, with the shuffling done within each epoch.
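The desired order is "shuffle, then repeat": draw a fresh permutation of the whole dataset for each epoch and emit it completely before starting the next one. A minimal pure-Python sketch follows; the function name is hypothetical. For comparison, the tf.data API expresses the same idea as dataset.shuffle(buffer_size).repeat(num_epochs) with a buffer at least as large as the dataset, whereas repeat(...) followed by shuffle(...) lets neighbouring epochs blend, much like the queue pipeline above.

```python
import random

def per_epoch_shuffle(records, num_epochs, seed=None):
    """Yield (epoch, record) pairs; each epoch is an independent permutation."""
    rng = random.Random(seed)
    for epoch in range(num_epochs):
        order = list(records)
        rng.shuffle(order)  # reshuffle separately for every epoch
        for rec in order:
            yield epoch, rec

pairs = list(per_epoch_shuffle([1, 2, 3], num_epochs=2, seed=0))
epoch0 = [r for e, r in pairs if e == 0]
epoch1 = [r for e, r in pairs if e == 1]
print(epoch0, epoch1)  # each list is some permutation of [1, 2, 3]
```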
Also, how can I tell when one epoch has finished? Is there some flag variable? Thanks!
Here is my code:
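As far as I know, the queue-based shuffle_batch output carries no per-epoch flag, because the queue only sees one long stream of records. A common workaround is to run the outer epoch loop yourself and build each epoch as its own finite pass (for example a pipeline with num_epochs=1, where the tf.errors.OutOfRangeError the code below already catches marks the end of the epoch, or a tf.data initializable iterator that is re-initialized per epoch). The plain-Python loop below is a hypothetical sketch of that structure: leaving the inner loop is the end-of-epoch signal.

```python
import random

def run_epochs(records, batch_size, num_epochs, seed=None):
    """Hypothetical training loop returning {epoch: [batch, ...]}.

    Each epoch is a separately shuffled pass over `records`, cut into
    batches whose last one may be smaller.
    """
    rng = random.Random(seed)
    history = {}
    for epoch in range(num_epochs):
        order = list(records)
        rng.shuffle(order)
        history[epoch] = []
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            history[epoch].append(batch)  # stand-in for one training step
        # The inner loop is exhausted here, so the epoch is known to be done.
        print(f"epoch {epoch} done after {len(history[epoch])} batches")
    return history

history = run_epochs(list(range(1, 9)), batch_size=3, num_epochs=2, seed=1)
```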
import tensorflow as tf

def read_my_file_format(filename_queue):
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    record_defaults = [['1'], ['1'], ['1']]
    col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defaults, field_delim='-')
    # col1 = list(map(int, col1.split(',')))
    # col2 = list(map(int, col2.split(',')))
    return col1, col2, col3

def input_pipeline(filenames, batch_size, num_epochs=1):
    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, shuffle=True)
    col1, col2, col3 = read_my_file_format(filename_queue)
    min_after_dequeue = 10
    capacity = min_after_dequeue + 3 * batch_size
    col1_batch, col2_batch, col3_batch = tf.train.shuffle_batch(
        [col1, col2, col3], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue, allow_smaller_final_batch=True)
    return col1_batch, col2_batch, col3_batch

filenames = ['1.txt']
batch_size = 3
num_epochs = 1
a1, a2, a3 = input_pipeline(filenames, batch_size, num_epochs)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    # start populating filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            a, b, c = sess.run([a1, a2, a3])
            print(a, b, c)
    except tf.errors.OutOfRangeError:
        print('Done training, epoch reached')
    finally:
        coord.request_stop()
    coord.join(threads)
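With the 8-line data file listed below, batch_size = 3 and num_epochs = 1, the loop above should print three batches before the OutOfRangeError ends it. The helper below is a hypothetical counting sketch that treats shuffle_batch as one continuous stream of num_examples * num_epochs records (an assumption about the queue, which has no notion of epochs). It also shows that with num_epochs = 2 the only short batch is the very last one, so by count alone the third batch already has to straddle the epoch boundary.

```python
def batch_sizes(num_examples, batch_size, num_epochs, allow_smaller_final_batch=True):
    """Sizes of the batches one continuous stream would be cut into."""
    total = num_examples * num_epochs
    full, remainder = divmod(total, batch_size)
    sizes = [batch_size] * full
    if remainder and allow_smaller_final_batch:
        sizes.append(remainder)  # leftover records form one smaller final batch
    return sizes

print(batch_sizes(8, 3, 1))  # [3, 3, 2]
print(batch_sizes(8, 3, 2))  # [3, 3, 3, 3, 3, 1]
```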
My data looks like this:
1,2-3,4-A
7,8-9,10-B
12,13-14,15-C
17,18-19,20-D
22,23-24,25-E
27,28-29,30-F
32,33-34,35-G
37,38-39,40-H
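Each line is split by decode_csv on field_delim='-' into three string columns, and the commented-out lines in read_my_file_format hint at turning the first two into integer lists. The plain-Python function below (hypothetical name) shows what that parsing amounts to for one line. Inside the graph, col1 is a string tensor rather than a Python string, so Python's .split cannot be applied to it there; TensorFlow string ops (for example tf.string_split and tf.string_to_number in TF 1.x) would be needed instead.

```python
def parse_line(line):
    """Split one line the way field_delim='-' does, then parse the numbers."""
    col1, col2, col3 = line.split('-')
    return (list(map(int, col1.split(','))),
            list(map(int, col2.split(','))),
            col3)

print(parse_line('1,2-3,4-A'))  # ([1, 2], [3, 4], 'A')
```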
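Can you add the code that generates the tensors 'col1', 'col2', 'col3'? The way the code is written suggests you shuffle at the end of the pipeline, so everything will be mixed together. – MZHm
I added my code and data. @MZHm – danche
You may want to look at this answer to see whether it addresses a similar problem: https://stackoverflow.com/a/44526962/4282745 – npf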