2017-07-24 126 views
1

我想通过在Tensorflow中排队我的数据(x_data和y_target)来开发随机批生成。这里我的代码:随机批生成器Tensorflow

x_data=np.asarray(x_data) # shape (2404,60,41,2) 
y_target=np.asarray(y_data) # shape (2404,7) 

queue_x_data = tf.placeholder(tf.float64, shape=[20,60,41,2]) 
queue_y_target = tf.placeholder(tf.int64, shape=[20,7]) 
q = tf.FIFOQueue(capacity=50, dtypes=[tf.float64,tf.int64], shapes=[(60,41,2), (7)]) 
init = q.enqueue_many([queue_x_data, queue_y_target]) 
dequeue_op = q.dequeue() 
data_batch, target_batch = tf.train.shuffle_batch(dequeue_op,batch_size=32,num_threads=4,capacity=50000, min_after_dequeue=10000) 

with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
    curr_data_batch, curr_target_batch = sess.run([data_batch, target_batch]) 

在这里,我会尝试只是打印生成的批处理。

print(curr_data_batch) 
    coord.request_stop() 
    coord.join(threads) 

但我无法得到任何生成的批处理打印。我究竟做错了什么?由于

回答

0

有2个原因。

  • 首先,您从未打电话给init操作:sess.run(init, {queue_x_data: x_data, queue_y_target: y_target}))
  • 第二,min_after_dequeue=10000是与队列中元素的数量进行比较,尝试将其设置为较低的级别。
0

您尝试从占位符填满你的队列,但占位符必须用数据通过转移来填充free_dict参数然后调用sess.run()