2017-06-13 132 views
0

我正在使用张量流的imageNet训练模型来分类图像的多个类别。ValueError:GraphDef不能大于2GB

我编辑的脚本classify.py作为

import tensorflow as tf 
import sys 
import glob 
import os 
import pandas as pd 

# Disable tensorflow compilation warnings 
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
import tensorflow as tf 

test_path = '/Users/kaustubhmundra/Desktop/Multi-Class Classifier/test' 

classes = ['room','reception','washroom','facade'] 

result = pd.DataFrame(columns = ['facade','washroom','room','reception']) 

def predict(image_path): 
    #image_path = sys.argv[1] 

    # Read the image_data 
    image_data = tf.gfile.FastGFile(image_path, 'rb').read() 

    # Loads label file, strips off carriage return 
    label_lines = [line.rstrip() for line 
         in tf.gfile.GFile("tf_files/retrained_labels.txt")] 

    # Unpersists graph from file 
    with tf.gfile.FastGFile("tf_files/retrained_graph.pb", 'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     _ = tf.import_graph_def(graph_def, name='') 

    with tf.Session() as sess: 
     # Feed the image_data as input to the graph and get first prediction 
     softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') 

     predictions = sess.run(softmax_tensor, \ 
       {'DecodeJpeg/contents:0': image_data}) 

     # print(predictions) 

     pred = pd.DataFrame(predictions,columns = ['facade','washroom','room','reception']) 

     # print(pred) 

     global result 

     result = result.append(pred) 

     # print(result) 

     # Sort to show labels of first prediction in order of confidence 
     top_k = predictions[0].argsort()[-len(predictions[0]):][::-1] 

     for node_id in top_k: 
      human_string = label_lines[node_id] 
      score = predictions[0][node_id] 
      print('%s (score = %.5f)' % (human_string, score)) 



path = os.path.join(test_path, '*') 
files = sorted(glob.glob(path)) 

i=1 

for fl in files: 
    print(i) 
    i = i + 1 
    predict(fl) 

result.to_csv('predictions.csv') 

虽然我用它来预测上的图像,它完美的作品,直到24倍的图像,但随后显示了一个错误:

File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2154, in _as_graph_def raise ValueError("GraphDef cannot be larger than 2GB.") ValueError: GraphDef cannot be larger than 2GB.

我如何解决这个问题?

回答

0

您每次调用predict()时都会导入图表,因此您正在累积非常大的默认graphdef。您应该更改代码,以便仅在预测函数之外加载一次图形(“文件中的#Unpersists图形”部分)。这也可以大大加快你的代码。

+0

非常感谢! 这个工程令人惊讶。这很简单,我不知道为什么它没有打我。 :) –

相关问题