2017-04-05 98 views
0

我正在尝试在一组不同的类别标签(2 个类别)上微调 TensorFlow Slim 的 VGG16 网络(不包括 fc8 层)。执行时,我得到了下面这个错误。TF-Slim ValueError:无法压缩(squeeze)dim[1],期望该维度的大小为 1,但 'vgg_16/fc8/squeezed'(op:'Squeeze')得到的是 3,输入形状为 [3,3,3,2]。

错误

logits, _ = vgg.vgg_16(images, num_classes=NUM_CLASSES, is_training=True) 
/models/slim/nets/vgg.py", line 178, in vgg_16 
net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 
/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 2273, in squeeze 
--- STACK TRACE OMITTED ----- 
/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl 
raise ValueError(err.message) 
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 3 for 'vgg_16/fc8/squeezed' (op: 'Squeeze') with input shapes: [3,3,3,2]. 

代码

# Number of examples per batch (passed as batch_size to tf.train.batch in load_batch).
BATCH_SIZE = 3
# Number of target classes (passed as num_classes to vgg.vgg_16).
NUM_CLASSES = 2
def load_batch():
    """Build the queue-based input pipeline and return one image/label batch.

    File paths and labels come from ``read_label_file(train_labels_file)``.
    Each JPEG is read, decoded with ``NUM_CHANNELS`` channels, resized to
    224x224 and grouped into batches of ``BATCH_SIZE`` by a single thread.

    Returns:
        A ``(image_batch, label_batch)`` tuple of tensors produced by
        ``tf.train.batch``.
    """
    paths, class_ids = read_label_file(train_labels_file)
    path_tensor = ops.convert_to_tensor(paths, dtype=dtypes.string)
    label_tensor = ops.convert_to_tensor(class_ids, dtype=dtypes.int32)

    # One (path, label) pair is dequeued at a time, in file order (no shuffling).
    path_item, label_item = tf.train.slice_input_producer(
        [path_tensor, label_tensor], shuffle=False)

    raw_bytes = tf.read_file(path_item)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=NUM_CHANNELS)
    # Declares the static shape of the decoded image.
    # NOTE(review): assumes every source image is 387x408 with 3 channels - confirm.
    decoded.set_shape([387, 408, 3])

    target_size = tf.constant([224, 224], dtype=tf.int32)
    resized = tf.image.resize_images(decoded, target_size)

    batched_images, batched_labels = tf.train.batch(
        [resized, label_item], batch_size=BATCH_SIZE, num_threads=1)
    return batched_images, batched_labels

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # Input tensors come from the queue-based pipeline defined in load_batch().
    images, labels = load_batch()

    # Build VGG16 under the default VGG arg scope, with NUM_CLASSES outputs.
    vgg_scope = vgg.vgg_arg_scope()
    with slim.arg_scope(vgg_scope):
        logits, _ = vgg.vgg_16(
            images, num_classes=NUM_CLASSES, is_training=True)
    .... 

回答

0

你可以先尝试直接定义一个批次:

with tf.Graph().as_default():

    tf.logging.set_verbosity(tf.logging.INFO)
    # Use synthetic inputs of a known shape in place of the input queue, so
    # the model can be built independently of the data pipeline.
    images = tf.random_uniform([BATCH_SIZE, 224, 224, 3])
    # Random integer class ids in [0, NUM_CLASSES); an integer dtype requires
    # maxval to be given explicitly.
    labels = tf.random_uniform([BATCH_SIZE], maxval=NUM_CLASSES, dtype=tf.int32)
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, _ = vgg.vgg_16(images, num_classes=NUM_CLASSES, is_training=True)

您也可以在把 images 和 labels 张量传入 vgg.vgg_16 之前,先调试(检查)它们的形状。

相关问题