2017-02-22 42 views

回答

2

考虑使用tf.TextLineReader,它与tf.train.string_input_producer一起允许您从磁盘上的多个文件(如果您的数据集足够大以至于需要将其分散到多个文件中)加载数据。

https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files

代码段从上面的链接:

filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

# Default values, in case of empty columns. Also specifies the type of the 
# decoded result. 
record_defaults = [[1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults) 
features = tf.stack([col1, col2, col3, col4]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for  filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

# Default values, in case of empty columns. Also specifies the type of the 
# decoded result. 
record_defaults = [[1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults) 
features = tf.stack([col1, col2, col3, col4]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(1200): 
    # Retrieve a single instance: 
    example, label = sess.run([features, col5]) 

    coord.request_stop() 
    coord.join(threads)i in range(1200): 
    # Retrieve a single instance: 
    example, label = sess.run([features, col5]) 

    coord.request_stop() 
    coord.join(threads) 
+0

谢谢您的anwser。但是,如果CSV文件中有**列**,该怎么办?我必须写很多col1,col2,col3 ...等等?以及如何从二进制文件读取数据? – secsilm

+0

@secsilm是的,您需要在您的CSV中为每列添加“col1”,“col2”等。记住'col1'只是一个变量名,所以你可以给它一个更多的助记符名称,比如'price'或者其他什么。有关二进制文件,请参阅https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader – Insectatorious

0

通常情况下,您无论如何都会使用批处理智能培训,因此您可以即时加载数据。例如,对于图像:

for bid in nrBatches: 
    batch_x, batch_y = load_data_from_hd(bid) 
    train_step.run(feed_dict={x: batch_x, y_: batch_y}) 

因此,您可以实时加载每个批次,只加载需要在任何特定时刻加载的数据。当然你的训练时间会增加,而使用硬盘代替内存来加载数据。

相关问题