2017-06-13 83 views
0

我期待使用1.2中可用的new Dataset API,但在应用简单的map转换时会遇到问题,该转换会在index table中查找单词。新的dataset.map转换和查找表:不兼容的字符串类型

考虑一个简单的例子:

import tensorflow as tf 

mapping_strings = tf.constant(["emerson", "lake", "palmer"]) 
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=mapping_strings, num_oov_buckets=1) 

dataset = tf.contrib.data.Dataset.from_tensor_slices(
    tf.constant(["emerson", "lake"])) 

# Here is the map operation that generates an error. 
dataset = dataset.map(lambda x: table.lookup(x)) 

iterator = dataset.make_one_shot_iterator() 
next_element = iterator.get_next() 

with tf.Session() as sess: 
    sess.run(tf.tables_initializer()) 
    sess.run(next_element) 

随着1.2.0-rc2,它会生成以下错误:

TypeError: In op 'string_to_index_Lookup/hash_table_Lookup', input types ([tf.string, tf.string, tf.int64]) are not compatible with expected types ([tf.string_ref, tf.string, tf.int64]) 

查找表需要一个tf.string_ref这个规定并没有得到满足。

由于我是TensorFlow的新手,我不认为这是一个错误,而是一个糟糕的用法。我的错误是什么?

谢谢!

编辑2017年6月15日:随着nightly版本,但是,它抛出另一个错误:

ValueError: Cannot capture a stateful node (name:string_to_index/hash_table, type:HashTableV2) by value. 

回答

3

您可能需要使用Dataset.make_initializable_iterator(),而不是Dataset.make_one_shot_iterator()因为哈希表状态。

下面的代码为我工作:

import tensorflow as tf 

mapping_strings = tf.constant(["emerson", "lake", "palmer"]) 
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=mapping_strings, num_oov_buckets=1) 

dataset = tf.contrib.data.Dataset.from_tensor_slices(
    tf.constant(["emerson", "lake"])) 

# Here is the map operation that generates an error. 
dataset = dataset.map(lambda x: table.lookup(x)) 

iterator = dataset.make_initializable_iterator() 
init_op = iterator.initializer 

with tf.Session() as sess: 
    sess.run(tf.tables_initializer()) 
    sess.run(init_op) 
+0

我相信当你升级到TensorFlow的最新(夜间)版本的工作原理长。在1.2.0的最新版本候选版本中存在一个bug,其中一些废弃的内核(使用旧式'tf.string_ref'而不是'tf.resource')继续在'tf.contrib.lookup.index_table_from_tensor )',现在已经修复了这些问题,以便在主分支中使用新版本。 – mrry

+0

谢谢你们两位。所以解决方法是升级TensorFlow并使用迭代器初始值设定项。我更新了我的问题,以反映用TensorFlow版本得到的错误,而没有提到错误@mrry,我接受了Satoshi的答案。 – guillaumekln

+0

tf.string_ref错误仍然存​​在于1.2.0 final中。 – guillaumekln

相关问题