2017-08-04 64 views
0

这个张量流代码没有响应,我找不出原因。请帮忙!Tensorflow以无限循环结束

import tensorflow as tf 
#reading the file 
with tf.name_scope ('File_reading') as scope: 
    filename_queue = tf.train.string_input_producer(["forestfires.csv.digested"]) 
    reader = tf.TextLineReader() 
    key, value = reader.read(filename_queue) 
    record_defaults = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0]] 
    #13 decoded 
    col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13 = tf.decode_csv(
     value, record_defaults=record_defaults) 


    #12 is feture, and the 13th is the training data 
    features = tf.stack([col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12],name='data_input') 

    with tf.Session() as sess: 
     # Start populating the filename queue. 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     for i in range(517): 
      # Retrieve a single instance: 
      example, label = sess.run([features, col13]) 

     coord.request_stop() 
     coord.join(threads) 
with tf.name_scope ('network') as scope: 
    W1=tf.Variable(tf.zeros([12, 8]), name='W1') 
    b1=tf.Variable(tf.zeros([8]), name="b1") 
    h1=tf.add(tf.matmul(tf.expand_dims(features,0), W1),b1, name='hidden_layer') 
    W2=tf.Variable(tf.zeros([8, 1]), name='W2') 
    b2=tf.Variable(tf.zeros([1]), name="b2") 
    output=tf.add(tf.matmul(h1, W2),b2, name='output_layer') 
error=tf.add(output,-col13, name='error') 
#training 
train_step = tf.train.AdamOptimizer(1e-4).minimize(error) 
#graphing the output 
file_writer = tf.summary.FileWriter('some directory', sess.graph) 
with tf.Session() as sess: 
    #init 
    tf.global_variables_initializer().run() 
    print ("\n\n\n\n\n\nTRAINING STARTED\n\n\n\n\n\n") 
    print('test1') 
    sess.run(error) #this statement causes an infinite loop 
    print ('test2') 
file_writer.close() 

该代码运行并打印'test1',但它什么都不做,甚至没有响应ctrl + c。我试图查找问题,但是我的谷歌技能不够好,或者它不在互联网上。 system:win10 geforce 960M python 3.5.2

回答

0

您构建网络的方式在智力上并不会使敏感。如果您需要从TextLineReader读取517个步骤,请使用函数read_up_to并提供值517,而不是使用单独的会话。按照您构建图表的方式,输入阅读器与图形的其余部分之间似乎没有一个简洁的连接。

我的建议:

# define graph which includes the input queue 
def model(...): 
... 
    return error, metrics 

with tf.Graph.as_default(): 
    error, metrics = model(...) 

    with tf.Session(): 
    # Start Coordinator 
    # Initialise global vars 
    # Start queue runners 
    # model_error, model_metrics = sess.run([error, metrics]) 
0

解决它(这个错误),这不是一个无限循环,它只是等待输入数据。出于某种原因,如果我将上面的'with tf.Session()as sess:'块(不带with部分)粘贴到块的顶部,它会很好地运行。 (也许有可能,还有一些其他编码错误,因为自那以后我改变了一些东西。)