2017-07-18 42 views
0

我想问一下tf.one_hot()函数是否支持SparseTensor作为“indices”参数。我想要做一个多标签分类(每个例子都有多个标签),这需要计算一个cross-trape丢失。tf.one_hot()是否支持SparseTensor作为索引参数?

我试图直接把SparseTensor在“指数”参数,但它提出了以下错误:

类型错误:未能类型的对象转换为张量。内容:SparseTensor(indices = Tensor(“read_batch_features/fifo_queue_dequeue:106”,shape =(?, 2),dtype = int64,device =/job:worker),values = Tensor(“string_to_index_Lookup:0”,shape =(? ,),dtype = int64,device =/job:worker),dense_shape = Tensor(“read_batch_features/fifo_queue_dequeue:108”,shape =(2,),dtype = int64,device =/job:worker))。考虑将元素转换为支持的类型。

任何有关可能原因的建议?

谢谢。

回答

0

one_hot不支持SparseTensor作为indices参数。您可以通过稀疏张量的索引/值张量作为索引参数,这可能会解决您的问题。

0

您可以从最初的SparseTensor构建另一个形状为(batch_size, num_classes)的SparseTensor。例如,如果你把你的班级在一个字符串特征柱(用空格隔开),可以使用下列内容:

import tensorflow as tf 

all_classes = ["class1", "class2", "class3"] 
classes_column = ["class1 class3", "class1 class2", "class2", "class3"] 

table = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(all_classes) 
) 
classes = tf.constant(classes_column) 
classes = tf.string_split(classes) 
idx = table.lookup(classes) # SparseTensor of shape (4, 2), because each of the 4 rows has at most 2 classes 
num_items = tf.cast(tf.shape(idx)[0], tf.int64) # num items in batch 
num_entries = tf.shape(idx.indices)[0] # num nonzero entries 

y = tf.SparseTensor(
    indices=tf.stack([idx.indices[:, 0], idx.values], axis=1), 
    values=tf.ones(shape=(num_entries,), dtype=tf.int32), 
    dense_shape=(num_items, len(all_classes)), 
) 
y = tf.sparse_tensor_to_dense(y, validate_indices=False) 

with tf.Session() as sess: 
    tf.tables_initializer().run() 
    print(sess.run(y)) 

    # Outputs: 
    # [[1 0 1] 
    # [1 1 0] 
    # [0 1 0] 
    # [0 0 1]] 

这里idx是SparseTensor。其索引idx.indices[:, 0]的第一列包含批次的行号,其值idx.values包含相关类ID的索引。我们结合这两个来创建新的y.indices

要全面实施多标签分类,请参见https://stackoverflow.com/a/47671503/507062的“选项2”