我试图执行上,我使用“AlexNet细化和微调与TensorFlow” https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html也不会在Java API运行Tensorflow预测
我保存在使用Python tf.saved_model.builder.SavedModelBuilder
模型训练的模型预测,并加载Java中的模型使用SavedModelBundle.load
。 代码的主要部分是:
SavedModelBundle smb = SavedModelBundle.load(path, "serve");
Session s = smb.session();
byte[] imageBytes = readAllBytesOrExit(Paths.get(path));
Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0);
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
float [] a = result.copyTo(new float[1][nlabels])[0];`
我得到这个异常:在线程
异常“主” java.lang.IllegalArgumentException异常:您必须养活一个值占位张量“ Placeholder_1'用dtype float [[Node:Placeholder_1 = Placeholder_output_shapes = [[]],dtype = DT_FLOAT,shape = [],_device =“/ job:localhost/replica:0/task:0/cpu:0”]]
我看到上面的代码为某些人工作,我无法弄清楚这里缺少的东西。 请注意,该网络熟悉节点“input_tensor”和“fc8/fc8”,因为它没有说它不知道它们。