2017-02-22 98 views
1

我试图使用keras model.fit_generator()来拟合模型,下面是我的发电机的定义:keras model fit_generator ValueError:检查模型目标时出现错误:期望cropping2d_4具有4个维度,但获取了具有形状的数组(32,1)

from sklearn.utils import shuffle 
IMG_PATH_PREFIX = "./data/IMG/" 
def generator(samples, batch_size=64): 
    num_samples = len(samples) 
    while 1: # Loop forever so the generator never terminates 
     shuffle(samples) 
     for offset in range(0, num_samples, batch_size): 
      batch_samples = samples[offset:offset+batch_size] 

      images = [] 
      angles = [] 
      for batch_sample in batch_samples: 
       name = IMG_PATH_PREFIX + batch_sample[0].split('/')[-1] 

       center_image = cv2.imread(name) 
       center_angle = float(batch_sample[3])     

       images.append(center_image) 
       angles.append(center_angle) 

     X_train = np.array(images) 
     y_train = np.array(angles) 

     #X_train = np.expand_dims(X_train, axis=0) 
     #y_train = np.expand_dims(y_train, axis=1) 
     print("X_train shape: ", X_train.shape, " y_train shape:", y_train.shape) 
     #print("X train: ", X_train) 
     yield X_train, y_train 

train_generator = generator(train_samples, batch_size = 32) 
validation_generator = generator(validation_samples, batch_size = 32) 

在这里,输出形状是: X_train形状:(32,160,320,3)y_train形状:(32,)

的模型拟合代码是:

model = Sequential() 
#cropping layer 
model.add(Cropping2D(cropping=((50,20), (1,1)), input_shape=(160,320,3), dim_ordering='tf')) 
model.compile(loss = "mse", optimizer="adam") 
model.fit_generator(train_generator, samples_per_epoch= len(train_samples), validation_data=validation_generator, nb_val_samples=len(validation_samples), nb_epoch=3) 

然后我得到的错误信息:

ValueError异常:错误检查时模型的目标:预计cropping2d_6有4种尺寸,但得到了与形状阵​​列(32 1)

有人能帮助让我知道什么是问题?

回答

1

这里最大的问题是:你知道你在做什么? 1)如果读取的是here,则输入是4D张量,输出也是4D张量。你的目标是形状的二维张量(batch_size,1)。所以当然,当keras试图计算具有3D(没有批量维度)的输出和具有1D(无批量维度)的目标之间的误差时,它无法从中理解。产出和目标必须具有相同的尺寸。

2)你知道cropping2D实际上在做什么吗?它正在裁剪你的图像...因此在裁剪尺寸的开始和结束时删除值。在你的情况下,你正在输出形状的图像(90,218,3)。这不是一个预测,在这个层面上训练没有重量,所以没有理由适合“模型”。你的模型只是裁剪图像。没有需要的培训。

相关问题