2017-06-20 108 views
0

我试图执行上,我使用“AlexNet细化和微调与TensorFlow” https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html也不会在Java API运行Tensorflow预测

我保存在使用Python tf.saved_model.builder.SavedModelBuilder模型训练的模型预测,并加载Java中的模型使用SavedModelBundle.load。 代码的主要部分是:

SavedModelBundle smb = SavedModelBundle.load(path, "serve"); 
    Session s = smb.session(); 
    byte[] imageBytes = readAllBytesOrExit(Paths.get(path)); 
    Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes); 
    Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0); 
    final long[] rshape = result.shape(); 
    if (result.numDimensions() != 2 || rshape[0] != 1) { 
     throw new RuntimeException(
       String.format(
         "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", 
         Arrays.toString(rshape))); 
    } 
    int nlabels = (int) rshape[1]; 
    float [] a = result.copyTo(new float[1][nlabels])[0];` 

我得到这个异常:在线程

异常“主” java.lang.IllegalArgumentException异常:您必须养活一个值占位张量“ Placeholder_1'用dtype float [[Node:Placeholder_1 = Placeholder_output_shapes = [[]],dtype = DT_FLOAT,shape = [],_device =“/ job:localhost/replica:0/task:0/cpu:0”]]

我看到上面的代码为某些人工作,我无法弄清楚这里缺少的东西。 请注意,该网络熟悉节点“input_tensor”和“fc8/fc8”,因为它没有说它不知道它们。

回答

1

从错误消息中可以看出,您使用的模型期望得到另一个值(图中的节点名称为Placeholder_1,预期类型为浮点标量张量)。

看起来你已经定制了你的模型(而不是跟随你链接到逐字的文章)。也就是说,文章显示需要喂食的多个占位符,一个用于图像,另一个用于控制脱落。在文章中定义为:

keep_prob = tf.placeholder(tf.float32) 

此占位符的值需要提供。如果您正在进行推理,那么您想将keep_prob设置为1.0。类似于:

Tensor keep_prob = Tensor.create(1.0f); 
Tensor result = s.runner() 
    .feed("input_tensor", image) 
    .feed("Placeholder_1", keep_prob) 
    .fetch("fc8/fc8") 
    .run() 
    .get(0); 

希望有所帮助。