2017-05-08 183 views
0

我有一个运行良好的深层神经网络。然而,加入以下代码提起提早终止导致错误:使用SessionRunHook()或验证监视器在Tensorflow中提前停止

validation_metrics = { 
"accuracy": 
    tf.contrib.learn.MetricSpec(
     metric_fn=tf.contrib.metrics.streaming_accuracy, 
     prediction_key=tf.contrib.learn.prediction_key.PredictionKey. 
     CLASSES)} 


validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
x=X_test, y=y_test, early_stopping_rounds=50, metrics=validation_metrics) 

输出:

prediction_key=tf.contrib.learn.prediction_key.PredictionKey.CLASSES)} 
AttributeError: module 'tensorflow.contrib.learn' has no attribute 'prediction_key' 
+0

您可以添加更多有关哪里出错的详细信息,并查看是否可以减少重现错误所需的代码示例的大小?这将帮助我们更好地回答这个问题。 –

+0

谢谢 - 减少原始问题中的代码。完整的代码要点[这里](https://gist.github.com/KT12/7b081dfb776e8b0fde4d1275b980cc70) – KT12

+0

您使用的是tensorflow和python的版本? – Wontonimo

回答

0

你可以尝试以下方法:

prediction_key=tf.contrib.learn.PredictionKey.CLASSES