2017-03-15 64 views
1

我正尝试使用新的Java API从磁盘读取模型。如何在Tensorflow的Java API中使用`saver.save`加载模型保存

The one example要使用Tensorflow的Java API显示如何读取具有图形定义和参数权重的.pb模型文件。

在Python方面,Tensorflow建议使用Saver对象将模型保存到磁盘。它会创建一个.meta文件,该文件具有该定义并且具有.data文件的权重。在Python中,我使用new_saver=tf.train.import_meta_graph(var_filename) new_saver.restore(sess, model_filename)从磁盘读取模型。

如何在Java API中执行此操作?

回答

0

SavedModelBundle类可能是你正在寻找的。特别是,SavedModelBundle.load()将返回一个Session,您可以使用它来执行保存的模型。

请注意,此功能最近添加到Java API中,因此它不存在于二进制发行版中,因此您必须在发布TensorFlow 1.1之前build the Java API from source

+0

很好,谢谢。我目前的解决方案是使用'freeze_graph'保存图形def和权重,然后用Java读取。这个班看起来很有前途我会等到官方发布的代码尝试它, –

0

我正在做类似的事情,使用python界面在hadoop集群上训练模型,并使用模型和学习参数在java中进行预测。

用法是在Java端很简单:

SavedModelBundle load = SavedModelBundle.load(modelDir, "serve"); 
     float[][] resultArray; 
     try (Graph g = load.graph()) { 
      try (Session s = load.session(); 
       Tensor result = s.runner().feed("data", data).fetch("prediction").run().get(0)) { 
       resultArray = result.copyTo(new float[10][1]); 
      } 
     } 
     load.close(); 
     return resultArray; 

要获取的名称Feed和获取的操作可以打印签名,并使用输入和输出值名称。

print(prediction_signature) 

https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py#L119

相关问题