2016-11-27 38 views
0

我已经在张量流中创建了一个神经网络。这个网络是多标签的。 Ergo:它试图预测一个输入集的多个输出标签,在这种情况下为三个。目前我使用此代码来测试我的网络是如何准确地预测了三个标签:测试张量流网络:in_top_k()替换多标签分类

_, indices_1 = tf.nn.top_k(prediction, 3) 
_, indices_2 = tf.nn.top_k(item_data, 3) 
correct = tf.equal(indices_1, indices_2) 
accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
percentage = accuracy.eval({champion_data:input_data, item_data:output_data}) 

该代码工作正常。问题是现在我试图创建代码来测试index_1中找到的前3个项目是否位于indices_2中的前5个图像之中。我知道tensorflow有一个in_top_k()方法,但据我所知不接受multilabel。目前我一直在尝试使用一个for循环来对它们进行比较:

_, indices_1 = tf.nn.top_k(prediction, 5) 
_, indices_2 = tf.nn.top_k(item_data, 3) 
indices_1 = tf.unpack(tf.transpose(indices_1, (1, 0))) 
indices_2 = tf.unpack(tf.transpose(indices_2, (1, 0))) 
correct = [] 
for element in indices_1: 
    for element_2 in indices_2: 
     if element == element_2: 
      correct.append(True) 
     else: 
      correct.append(False) 
accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
percentage = accuracy.eval({champion_data:input_data, item_data:output_data}) 

然而,这是行不通的。代码运行但我的准确度始终为0.0。

所以我有两个问题之一:

1)有一个简单的替代in_top_k()接受,我可以使用,而不是自定义编写代码多标签分类?

2)如果不是1:我做错了什么导致我得到0.0的准确性?

回答

0

当你

correct = tf.equal(indices_1, indices_2) 

要检查不只是这两个指标是否包含相同的元素,但它们是否包含在相同的位置相同的元素。这听起来不像你想要的。

setdiff1d op会告诉你哪些索引在indices_1中,但不在indices_2中,然后可以用它来计算错误。

我认为过于严格的正确性检查可能是什么导致你得到一个错误的结果。

+0

非常感谢!这是朝着正确方向迈出的一大步。它确实需要我更新我的tensorflow,因为我的版本还没有setdiff1d。你介意如何计算错误吗?我已经尝试了一些东西,但似乎无法弄清楚如何知道setdif1d找到的许多差异。 –