2017-07-06 88 views
1

我正在查看this tutorial以创建带有Tensorflow的卷积神经网络。如何在新数据上运行卷积神经网络

神经网络建立并训练后,在本教程中,测试是这样的:

eval_results = mnist_classifier.evaluate(
    x=eval_data, y=eval_labels, metrics=metrics) 
print(eval_results) 

不过,我并没有对测试集的标签,所以我想它运行只是训练的例子,像这样:

eval_results = mnist_classifier.evaluate(x=test_data, metrics=metrics) 

如果我这样做,但是,我得到这样的警告,然后停止执行:

WARNING:tensorflow:From ../src/script.py:169: calling BaseEstimator.evaluate (from tensorflow.contrib.learn.python.learn.estimators.estimator) with x is deprecated and will be removed after 2016-12-01. 
Instructions for updating: 
Estimator is decoupled from Scikit Learn interface by moving into 
separate class SKCompat. Arguments x, y and batch_size are only 
available in the SKCompat class, Estimator will only accept input_fn. 
Example conversion: 
    est = Estimator(...) -> est = SKCompat(Estimator(...)) 
Traceback (most recent call last): 
    File "../src/script.py", line 172, in <module> 
    main() 
    File "../src/script.py", line 169, in main 
    eval_results = mnist_classifier.evaluate(x=test_data, metrics=metrics) 
    File "/opt/conda/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 289, in new_func 
    return func(*args, **kwargs) 
    File "/opt/conda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 530, in evaluate 
102.5s 
7 
    return SKCompat(self).score(x, y, batch_size, steps, metrics) 
    File "/opt/conda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1365, in score 
    name='score') 
    File "/opt/conda/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 816, in _evaluate_model 
    % self._model_dir) 
tensorflow.contrib.learn.python.learn.estimators._sklearn.NotFittedError: Couldn't find trained model at /tmp/mnist_convnet_model. 

回答

1

您可以使用mnist_classifier.evaluate没有y paramater,因为你有什么评价。

相反,使用y = mnist_classifier.predict(x=x)得到结果,并且看看他们自己知道,如果他们是正确与否。

然而,这是一个非常糟糕的想法是,网络已经培训了数据要做到这一点,因为这可能会导致不反映网络是如何处理新信息的好成绩。

除此之外,你得到的警告是非常正常的,因为xy参数被弃用反正,但你仍然可以使用它们。如果你想让警告本身消失,你应该在导入后添加tf.logging.set_verbosity(tf.logging.ERROR)tf

编辑:另外,当我想到它时,你如何为训练集设置标签而不是测试集?你应该总是分裂的训练数据和标签,这样大部分是受过训练的,但有一些是正好保持,所以你可以使用它在evaluate

+0

因为我在Kaggle,那里有没有标签做练习为测试集。 – octavian

+0

然后,我强烈建议将训练数据分成90%的训练和10%的训练数据保持未训练以用于“评估”(百分比取决于您实际拥有多少数据)。仅在'predict'中使用该测试数据并解析返回的结果。 –