2017-04-23 111 views
0

我可能在这里做了一些愚蠢的事情,但我不确定为什么会出现这个错误。tf.train.Features TypeError:不允许位置参数

此代码:

example = tf.train.Example(features=tf.train.Features(feature={ 
     'image/height': _int64_feature(FLAGS.img_height), 
     'image/width': _int64_feature(FLAGS.img_width), 
     'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)), 
     'image/channels': _int64_feature(channels), 
     'image/format': _bytes_feature(tf.compat.as_bytes(image_format)), 
     'image/label': _bytes_feature(label_img_buffer), 
     'image/label_path': _bytes_feature(tf.compat.as_bytes(os.path.basename(lbl_path))), 
     'image/fn_0': _bytes_feature(tf.compat.as_bytes(os.path.basename(ex_paths[0]))), 
     'image/encoded_0': _bytes_feature(tf.compat.as_bytes(ex_image_buffers[0])), 
     'image/fn_1': _bytes_feature(tf.compat.as_bytes(os.path.basename(ex_paths[1]))), 
     'image/encoded_1': _bytes_feature(tf.compat.as_bytes(ex_image_buffers[1])), 
     'image/fn_2': _bytes_feature(tf.compat.as_bytes(os.path.basename(ex_paths[2]))), 
     'image/encoded_2': _bytes_feature(tf.compat.as_bytes(ex_image_buffers[2]))})) 
return example 

但这代码不起作用(抛出类型错误的文章标题):

feature_dict={ 
     'image/height': _int64_feature(FLAGS.img_height), 
     'image/width': _int64_feature(FLAGS.img_width), 
     'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)), 
     'image/channels': _int64_feature(channels), 
     'image/format': _bytes_feature(tf.compat.as_bytes(image_format)), 
     'image/label': _bytes_feature(label_img_buffer), 
     'image/label_path': _bytes_feature(tf.compat.as_bytes(os.path.basename(lbl_path))), 
     } 

    for idx, image in sorted(ex_image_buffers.iteritems()): 
    img_key = 'image/encoded_' + str(idx) 
    fn_key = 'image/fn_' + str(idx) 
    feature_dict[img_key] = _bytes_feature(tf.compat.as_bytes(image)) 
    feature_dict[fn_key] = _bytes_feature(tf.compat.as_bytes(os.path.basename(ex_paths[idx]))) 

    example = tf.train.Example(features=tf.train.Features(feature_dict)) 
    return example 

ex_image_buffers是一个列表。

据我所知,tf.train.Features以一个字典作为参数,并且我在第一个例子和第二个例子中组装了相同的字典(我认为)。第二个允许我根据其他代码调整字典,所以我宁愿避免对不同字段进行硬编码。

想法?谢谢你的帮助!

回答

1

是的,我认为你有一个愚蠢的错误。尝试

example = tf.train.Example(features=tf.train.Features(feature=feature_dict))

为错误状态,tf.train.Features需要你按关键字/参数对通过。您需要添加关键字feature,就像您在第一个示例中所做的那样。

+0

谢谢,这工作。感谢帮助。为什么这个要求在这里? – user3325669

+0

我没有设计API,所以我不能说。您正在使用该类的构造函数,文档(https://www.tensorflow.org/api_docs/python/tf/train/Features)仅显示'** kwargs'作为输入。 –

+0

@JimParker链接到您提供的文档没有任何内容。有没有办法知道这些方法做什么,因为文档主要是所有方法的列表。 – deadcode