2017-10-16 61 views
2

我跑了一些教程代码text classificationTensorflow:解决

我可以运行这些脚本tf.estimator.inputs.numpy_input_fn功能和它的工作,但是当我试图通过线试图了解每一行运行步是干什么的,我有点困惑,在此步骤:

test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={WORDS_FEATURE: x_test}, 
    y=y_test, 
    num_epochs=1, 
    shuffle=False) 
classifier.train(input_fn=train_input_fn, steps=100) 

我知道概念train_input_fn是喂养数据的训练功能,但我怎么可以手动调用这个FN检查什么在它?

我跟踪的代码,发现了train_input_fn功能将资料提供至以下两个变量:

features 
Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>} 

labels 
Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32> 

当我试图做一个sess.run(功能)来评价的特征变量,我终端似乎卡住,停止响应。

什么是正确的方式来检查这些变量的内容?

谢谢!

回答

1

基于numpy_input_fn documentation和行为(挂起),我想底层的实现取决于队列运行。队列运行程序未启动时发生挂起。尝试修改您的会话中运行的脚本类似于以下,基于this guide

with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    try: 
     for step in xrange(1000000): 
      if coord.should_stop(): 
       break 
      features_data = sess.run(features) 
      print(features_data) 

    except Exception, e: 
     # Report exceptions to the coordinator. 
     coord.request_stop(e) 
    finally: 
     # Terminate as usual. It is safe to call `coord.request_stop()` twice. 
     coord.request_stop() 
     coord.join(threads) 

或者,我会鼓励你检查出tf.data.Dataset接口(可能tf.contrib.data.Dataset在tensorflow 1.3或之前)。您可以使用Dataset.from_tensor_slices而不使用队列来获得类似的输入/标签张量。创建稍微有点牵扯,但接口更加灵活,实现不使用队列运行器,这意味着会话运行更加简单。

import tensorflow as tf 
import numpy as np 

x_data = np.random.random((100000, 2)) 
y_data = np.random.random((100000,)) 

batch_size = 2 
buff = 100 


def input_fn(): 
    # possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier 
    dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) 
    dataset = dataset.repeat().shuffle(buff).batch(batch_size) 
    x, y = dataset.make_one_shot_iterator().get_next() 
    return x, y 


x, y = input_fn() 
with tf.Session() as sess: 
    print(sess.run([x, y])) 
+0

谢谢DomJack。代码在Python 3上进行了一些小改动。我不认为在Tensorflow中打印出张量的值是非常复杂的。 – Allen

+0

这是因为队列运行器的实现。我编辑了我的答案,以包含可能会有帮助的数据集示例。 '数据集'是比较新的,但是一旦你通过了一些样板,我发现它们非常简单,强大而且快速。 – DomJack

+0

谢谢@DomJack。我一定会检查出来的。我发现它有时非常直观地调试Tensorflow代码。 – Allen