2017-04-18 126 views
3

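"logits and labels must be same size" error with SoftmaxCrossEntropyWithLogits

I am completely new to TensorFlow. I have been trying to adapt the Deep MNIST tutorial to predict movie ratings on the MovieLens dataset. I simplified the model slightly so that, instead of a 5-point scale, it uses a simple binary Y/N rating (similar to the latest rating system on Netflix). I am trying to use only part of the ratings to predict preferences for new items. When I train the model, I get the following error and stack trace: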

Traceback (most recent call last):
  File "/Users/Eric/dev/Coding Academy Tutorials/tf_impl/deep_tf_group_rec_SO.py", line 223, in <module>
    train_step.run(feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 1550, in run
    _run_using_default_session(self, feed_dict, self.graph, session)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 3764, in _run_using_default_session
    session.run(operation, feed_dict)
  File "/Library/Python/2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/Library/Python/2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/Library/Python/2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/Library/Python/2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be same size: logits_size=[1,2] labels_size=[50,2]
    [[Node: SoftmaxCrossEntropyWithLogits = SoftmaxCrossEntropyWithLogits[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](Reshape_2, Reshape_3)]]

Caused by op u'SoftmaxCrossEntropyWithLogits', defined at:
  File "/Users/Eric/dev/Coding Academy Tutorials/tf_impl/deep_tf_group_rec_SO.py", line 209, in <module>
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
  File "/Library/Python/2.7/site-packages/tensorflow/python/ops/nn_ops.py", line 1617, in softmax_cross_entropy_with_logits
    precise_logits, labels, name=name)
  File "/Library/Python/2.7/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 2265, in _softmax_cross_entropy_with_logits
    features=features, labels=labels, name=name)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): logits and labels must be same size: logits_size=[1,2] labels_size=[50,2]
    [[Node: SoftmaxCrossEntropyWithLogits = SoftmaxCrossEntropyWithLogits[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](Reshape_2, Reshape_3)]]

The code that causes the error can be viewed here.

Dimensions of the variables used in the model:

  • x (?, 1682)
  • y_ (?, 2)
  • x_history (?, 290, 290, 1)
  • h_pool1 (?, 145, 145, 32)
  • h_pool2 (?, 73, 73, 64)
  • h_pool3 (?, 37, 37, 128)
  • h_pool4 (?, 19, 19, 256)
  • h_pool5 (?, 10, 10, 512)
  • h_fc1 (?, 1024)
  • h_fc1_drop (?, 1024)
  • y_conv (?, 2)

I created a gist: https://gist.github.com/EricSEkong/eaa67da30390a4eb2d50c282f3a2e4c7

If the ratings are binary, why is the label size 50x3? – Aaron

Oh, man! How silly of me. I'll fix that and see. Thanks

Answers

1

The problem is that you reshape your input batch (shape: 50 training instances x 1682 features) to [-1, 290, 290, 1] in this line:

x_history = tf.reshape(x, [-1, 290, 290, 1]) 

You can check the resulting shape of x_history:

x_history.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}).shape 

=> (1, 290, 290, 1) 

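This effectively takes all the features of your batch of 50 instances and packs them into a single instance: 50 * 1682 = 84100 = 290 * 290, so the first dimension becomes 1 where it needs to be 50. That single instance is then run through the rest of the network. Your cross-entropy evaluation therefore fails, because it cannot align the 50 target labels with the network's single output (logits_size=[1,2] versus labels_size=[50,2]).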
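The collapse can be reproduced without TensorFlow. The following is a minimal NumPy sketch, assuming a zero-filled stand-in batch of shape (50, 1682); the name batch_xs is borrowed from the traceback, and the data is made up. NumPy's reshape infers the -1 dimension from the total element count, which is also how tf.reshape treats it:

```python
import numpy as np

batch_size, num_features = 50, 1682

# Stand-in for one input batch (assumption: the real data would hold ratings).
batch_xs = np.zeros((batch_size, num_features), dtype=np.float32)

# Same target shape as in the question; -1 is inferred from the element count.
x_history = batch_xs.reshape(-1, 290, 290, 1)

print(batch_xs.size)    # 84100 elements in the whole batch
print(290 * 290)        # 84100 elements in one 290x290 grid
print(x_history.shape)  # (1, 290, 290, 1): the batch dimension is gone
```

Because the whole batch happens to fill exactly one 290x290 grid, the reshape succeeds silently instead of raising an error, and the mismatch only surfaces later at the loss.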

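You need to choose your layer shapes so that the batch dimension (the ? in your shape printout) is preserved through the network. One way to do this is to pad your features to 1764 and reshape to [-1, 42, 42, 1], since 42 * 42 = 1764.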
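As a rough illustration of the padding idea, here is a NumPy sketch that pads only the feature axis with zeros and then reshapes. The stand-in batch and the variable names (padded, x_image) are assumptions for the example, not part of the original script:

```python
import numpy as np

batch_size, num_features, side = 50, 1682, 42

# Stand-in batch (assumption: the values are irrelevant, only the shape matters).
batch_xs = np.ones((batch_size, num_features), dtype=np.float32)

# Pad only the feature axis (axis 1) with zeros, from 1682 up to 42 * 42 = 1764.
pad_width = side * side - num_features  # 82 extra columns
padded = np.pad(batch_xs, ((0, 0), (0, pad_width)), mode="constant")

# Each row now fills exactly one 42x42 grid, so -1 resolves to the batch size.
x_image = padded.reshape(-1, side, side, 1)

print(padded.shape)   # (50, 1764)
print(x_image.shape)  # (50, 42, 42, 1)
```

Inside the graph, something like tf.pad(x, [[0, 0], [0, 82]]) followed by tf.reshape should achieve the same effect, although that variant has not been run against the original script.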

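It is worth noting that 2D convolutions are most commonly used on image data, which is naturally two-dimensional. Given that your features are not two-dimensional, you might be better off starting with a simpler network of fully connected layers?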
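To show why a fully connected network sidesteps the reshape issue entirely, here is a shape-only NumPy sketch of a forward pass with one hidden layer. The hidden size of 1024 mirrors h_fc1 from the question; the weights are random and untrained, and every name here is illustrative:

```python
import numpy as np

rng = np.random.RandomState(0)
batch_size, num_features, hidden, num_classes = 50, 1682, 1024, 2

# Stand-in input batch (assumption: random values instead of real ratings).
x = rng.rand(batch_size, num_features).astype(np.float32)

# Randomly initialised, untrained weights; used here for shape checking only.
W1 = rng.randn(num_features, hidden).astype(np.float32) * 0.01
b1 = np.zeros(hidden, dtype=np.float32)
W2 = rng.randn(hidden, num_classes).astype(np.float32) * 0.01
b2 = np.zeros(num_classes, dtype=np.float32)

h_fc1 = np.maximum(np.dot(x, W1) + b1, 0.0)  # ReLU hidden layer, (50, 1024)
logits = np.dot(h_fc1, W2) + b2              # one row of logits per example

print(logits.shape)  # (50, 2)
```

A matrix product only ever mixes the feature axis, so the logits keep one row per example and line up with labels of shape (50, 2).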

0

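To get past this problem, I ended up dropping the batch size from 50 to one

batch = create_batches(train_data_pairs, 1) 

and significantly increasing the number of training iterations. In addition, to measure accuracy, I test the model on a large number of small slices of the test data and then take the mean of those evaluations.

This is more of a workaround than a real solution, but it lets me continue exploring other aspects of TensorFlow and working with the data in different ways.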
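The slice-wise evaluation described above might look roughly like the following sketch. The helper mean_slice_accuracy and the sample arrays are hypothetical; they are not taken from the original script, and the predictions are assumed to be class indices already extracted from the model output:

```python
import numpy as np

def mean_slice_accuracy(predictions, labels, slice_size):
    """Average the accuracy computed on consecutive slices of the test set."""
    accuracies = []
    for start in range(0, len(labels), slice_size):
        pred = predictions[start:start + slice_size]
        true = labels[start:start + slice_size]
        accuracies.append(np.mean(pred == true))
    return float(np.mean(accuracies))

# Made-up class indices, for illustration only.
labels = np.array([1, 0, 1, 1, 0, 0, 1, 0])
predictions = np.array([1, 0, 0, 1, 0, 1, 1, 0])

print(mean_slice_accuracy(predictions, labels, slice_size=2))  # 0.75
```

When every slice has the same size, this mean equals the accuracy over the whole test set; if the last slice is shorter, the two can differ slightly.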
