2017-08-04 106 views
0

我知道这是一个常见错误,但我无法理解此问题。这里是我的代码:Tensorflow:您必须为dtype float提供占位符张量'input_image'的值

def convert_image(url): 

    checkpoint_file = './vgg_16.ckpt' 

    input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image') 
    scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor) 
    scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5) 
    scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0) 

    #Load the model 
    sess = tf.Session() 
    arg_scope = vgg_arg_scope() 
    with slim.arg_scope(arg_scope): 
     logits, end_points = vgg_16(scaled_input_tensor, is_training=False) 
    saver = tf.train.Saver() 
    saver.restore(sess, checkpoint_file) 

    response = requests.get(url) 
    img = Image.open(BytesIO(response.content)) 
    im = np.array(img, dtype='float32') 
    im = im.reshape(-1,224,224,3) 

    features = sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im}) 
    sess.close() 
    return np.squeeze(features) 

正如你所看到的,我使用VGG_16预训练模型来提取fc7特征。大约50%的代码只是从URL获取图像并将其转换为224x224x3;另外50%的张量流工作得到实际的特征表示。

事情是,我第一次运行这个代码,它工作正常。但是,第二次,我得到了上述错误。当然,即使我收到这个错误,“im”也是一个float32。所以我认为这个问题与第二次运行此功能时遇到的问题有关。如果我不得不猜测,这与“保护者”的工作方式有关,但我一直无法弄清楚究竟是什么。

任何想法?

回答

1

错误很可能是由于您重新定义了input_tensor,而不是在VGG模型中使用输入占位符。您可以在将输入图像提供给网络之前对输入图像im应用转换。

此外,您为每个图像加载模型。 相反,加载模型一次,然后迭代循环内的图像列表。 类似这样的:

def convert_images(url_list): 
    # Load the TF model 
    #..... 
    # Session, etc. 

    # Now, go over the list of images one by one 
    for url in url_list: 
     image = ... # get image 
     features = session.run(...) # extract features 
+0

我同意这是更好的编码实践。实际上,我会继续重构整个代码,以便不必多次加载模型。 我希望我知道为什么我写的代码出错了。 – anon

+0

我认为,实际的问题是:您重新定义输入(input_tensor),而不是使用模型中已有的输入占位符。所以,你应该看看你的模型中的'输入',并将其与图像一起输入。 – Blackberry

相关问题