2016-09-28

tf.contrib.learn Quickstart: Fix float64 warning

I am getting started with TensorFlow using the posted tutorials.

I have the Linux CPU build of TensorFlow 0.10.0 for Python 2.7, running on Fedora 23 (Twenty Three).

I am trying to follow the tf.contrib.learn Quickstart tutorial with the code below.

https://www.tensorflow.org/versions/r0.10/tutorials/tflearn/index.html#tf-contrib-learn-quickstart

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import tensorflow as tf 
import numpy as np 

# Data sets 
IRIS_TRAINING = "IRIS_data/iris_training.csv" 
IRIS_TEST = "IRIS_data/iris_test.csv" 

# Load datasets. 
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING, 
                target_dtype=np.int) 
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST, 
               target_dtype=np.int) 

# Specify that all features have real-value data 
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] 

# Build 3 layer DNN with 10, 20, 10 units respectively. 
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, 
             hidden_units=[10, 20, 10], 
             n_classes=3, 
             model_dir="/tmp/iris_model") 

# Fit model. 
classifier.fit(x=training_set.data, 
      y=training_set.target, 
      steps=2000) 

# Evaluate accuracy. 
accuracy_score = classifier.evaluate(x=test_set.data, 
           y=test_set.target)["accuracy"] 
print('Accuracy: {0:f}'.format(accuracy_score)) 

# Classify two new flower samples. 
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float) 
y = classifier.predict(new_samples) 
print('Predictions: {}'.format(str(y))) 
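The float64 warnings concern the dtypes of the NumPy arrays passed to fit, evaluate and predict. A quick NumPy-only check (no TensorFlow required) shows that dtype=float in new_samples means 64-bit floats, which matches the warning printed just before the predictions in the log below. The sketch uses the built-in int rather than np.int, which was only an alias for int and has been removed in recent NumPy releases.

```python
import numpy as np

# Same two samples as in the tutorial; dtype=float is Python's float, i.e. float64.
new_samples = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
print(new_samples.dtype)  # float64

# An explicit cast produces the 32-bit type that the warning recommends.
new_samples_f32 = new_samples.astype(np.float32)
print(new_samples_f32.dtype)  # float32

# Integer labels: the built-in int maps to NumPy's default integer type
# (int64 on 64-bit Linux, as in the "targets" line of the log).
labels = np.array([0, 1, 2], dtype=int)
print(labels.dtype)
```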

The code executes, but it prints float64 warnings, as follows:

$ python confErr.py 
WARNING:tensorflow:load_csv (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed after 2016-09-15. 
Instructions for updating: 
Please use load_csv_{with|without}_header instead. 
WARNING:tensorflow:load_csv (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed after 2016-09-15. 
Instructions for updating: 
Please use load_csv_{with|without}_header instead. 
WARNING:tensorflow:Using default config. 
WARNING:tensorflow:float64 is not supported by many models, consider casting to float32. 
WARNING:tensorflow:Setting feature info to TensorSignature(dtype=tf.float64, shape=TensorShape([Dimension(None), Dimension(4)]), is_sparse=False) 
WARNING:tensorflow:Setting targets info to TensorSignature(dtype=tf.int64, shape=TensorShape([Dimension(None)]), is_sparse=False) 
WARNING:tensorflow:float64 is not supported by many models, consider casting to float32. 
WARNING:tensorflow:Given features: Tensor("input:0", shape=(?, 4), dtype=float64), required signatures: TensorSignature(dtype=tf.float64, shape=TensorShape([Dimension(None), Dimension(4)]), is_sparse=False). 
WARNING:tensorflow:Given targets: Tensor("output:0", shape=(?,), dtype=int64), required signatures: TensorSignature(dtype=tf.int64, shape=TensorShape([Dimension(None)]), is_sparse=False). 
Accuracy: 0.966667 
WARNING:tensorflow:float64 is not supported by many models, consider casting to float32. 
Predictions: [1 1] 

Note: replacing 'load_csv()' with 'load_csv_with_header()' produces the correct predictions, but the float64 warning persists.
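The tutorial's CSV files begin with a header row giving the number of samples, the number of features, and then the class names, and load_csv_with_header parses that layout. The sketch below is a hypothetical, NumPy-only loader for the same layout (load_csv_with_header_sketch and the Dataset tuple are illustrative names, not TensorFlow's code). It allocates its output arrays with the requested dtypes, so a features_dtype of np.float32 really does yield float32 data. The two data rows are made up for the example.

```python
import csv
import io
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for the (data, target) pair the tutorial works with.
Dataset = namedtuple("Dataset", ["data", "target"])


def load_csv_with_header_sketch(csv_file, target_dtype, features_dtype, target_column=-1):
    """Parse a CSV whose first row is: n_samples, n_features, <class names...>."""
    reader = csv.reader(csv_file)
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    # Allocate both arrays with the requested dtypes up front.
    data = np.zeros((n_samples, n_features), dtype=features_dtype)
    target = np.zeros((n_samples,), dtype=target_dtype)
    for i, row in enumerate(reader):
        # The label column is removed from the row; the rest are features.
        target[i] = np.asarray(row.pop(target_column), dtype=target_dtype)
        data[i] = np.asarray(row, dtype=features_dtype)
    return Dataset(data=data, target=target)


sample_csv = io.StringIO(
    "2,4,setosa,versicolor,virginica\n"
    "6.4,2.8,5.6,2.2,2\n"
    "5.0,2.3,3.3,1.0,1\n"
)
ds = load_csv_with_header_sketch(sample_csv, target_dtype=np.int32, features_dtype=np.float32)
print(ds.data.dtype, ds.target.dtype)  # float32 int32
```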

I tried explicitly setting the dtype (np.int32; np.float32; tf.int32; tf.float32) for training_set, test_set and new_samples.

I also tried casting feature_columns:

feature_columns = tf.cast(feature_columns, tf.float32) 
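tf.cast operates on tensors, whereas feature_columns is a list of column descriptions (metadata about the input), not the input data itself; the log lines about "Given features" refer to the arrays passed as x. Assuming the warning is triggered by the dtype of those arrays, the place to cast is the NumPy data, before calling fit, evaluate or predict. The sketch below uses a made-up stand-in array instead of training_set.data, and has not been verified to silence the warning in 0.10.0.

```python
import numpy as np

# Stand-in for training_set.data: float literals give a float64 array by default.
features_f64 = np.array([[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0]])

# Cast the data itself, not the feature-column specification.
# astype returns a new array and leaves the original untouched.
x_train = features_f64.astype(np.float32)

# Hypothetical usage with the tutorial's classifier (TensorFlow 0.10 API, not run here):
# classifier.fit(x=x_train, y=training_set.target, steps=2000)
print(features_f64.dtype, x_train.dtype)  # float64 float32
```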

Is the float64 problem a known development issue, or is there some workaround?

Answer


I received a reply from the development team via GitHub:

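Hi @qweelar, the float64 warning is caused by a bug in the load_csv_with_header function that was fixed in commit b6813bd. The fix is not in the TensorFlow 0.10 release, but it should be in the next release.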
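The reply does not describe the bug in detail, and the sketch below does not reproduce commit b6813bd. It only illustrates, as an assumption, one common way a loader can return float64 even when float32 is requested: the output buffer is pre-allocated with np.zeros, whose default dtype is float64, and each parsed row is copied into it, taking on the buffer's dtype.

```python
import numpy as np

n_samples, n_features = 2, 4
rows = [["6.4", "2.8", "5.6", "2.2"], ["5.0", "2.3", "3.3", "1.0"]]

# Buffer allocated without a dtype: NumPy defaults to float64.
buggy = np.zeros((n_samples, n_features))
# Buffer allocated with the requested dtype.
fixed = np.zeros((n_samples, n_features), dtype=np.float32)

for i, row in enumerate(rows):
    # Each row is parsed as float32, but assignment converts it to the buffer's dtype.
    buggy[i] = np.asarray(row, dtype=np.float32)
    fixed[i] = np.asarray(row, dtype=np.float32)

print(buggy.dtype)  # float64
print(fixed.dtype)  # float32
```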

In the meantime, for the purposes of the tf.contrib.learn Quickstart, you can safely ignore the float64 warning.
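If you would rather hide that one message than read past it, a standard-library logging filter is one option. This assumes TensorFlow routes its warnings through Python's logging module under the logger name "tensorflow" (worth checking in your installed version); the sketch only simulates two of the logged warnings on a logger of that name, so it runs without TensorFlow. A coarser alternative, where available, is tf.logging.set_verbosity(tf.logging.ERROR), which hides all warnings, including the deprecation notices.

```python
import logging


class DropFloat64Warning(logging.Filter):
    """Reject log records whose message contains the float64 advisory."""

    def filter(self, record):
        return "float64 is not supported" not in record.getMessage()


class ListHandler(logging.Handler):
    """Collect messages in a list so the effect of the filter is visible."""

    def __init__(self):
        logging.Handler.__init__(self)
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


# Assumption: TensorFlow logs through the standard logger named "tensorflow".
tf_logger = logging.getLogger("tensorflow")
handler = ListHandler()
tf_logger.addHandler(handler)
# A logger-level filter runs before any handler sees the record.
tf_logger.addFilter(DropFloat64Warning())

# Simulate two of the warnings from the log in the question.
tf_logger.warning("float64 is not supported by many models, consider casting to float32.")
tf_logger.warning("Using default config.")

print(handler.messages)  # ['Using default config.']
```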

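(Side note: regarding the other deprecation warnings, I will update the tutorial code to use load_csv_with_header, and I will update this issue once that is in place.)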