2017-04-27 95 views
0

我似乎一直搞不清楚如何重塑数据以适应模型。我认为输入和输出数据的形状必须匹配,但我仍然迷失在如何去做这件事。如何调整从灰度到RGB的图像?

我认为我的主要问题是灰度图像和RGB图像的存储方式不同。 [1]对[255,255,255]

因此,如果:

屏幕= cv2.cvtColor(屏幕,cv2.COLOR_BGR2RGB)

改变为:

屏幕= cv2.cvtColor(屏幕, cv2.COLOR_BGR2GRAY)

一切工作正常。

有问题的代码:

# Capture Data (CUT SHORT) 
WIDTH = 160 
HEIGHT = 120 
screen = cv2.cvtColor(screen, cv2.COLOR_BGR2RGB) 
screen = cv2.resize(screen, (WIDTH, HEIGHT)) 
dataset = [] 
output = [0, 0, 0, 0] 
dataset.append([screen, output]) 
np.save("training.npy", dataset) 

# Build Model 
https://github.com/tflearn/tflearn/blob/master/examples/images/alexnet.py 

# Changed to match output. 
network = fully_connected(network, 4, activation='softmax') 

# Train Data 
WIDTH = 160 
HEIGHT = 120 
LR = 1e-3 
EPOCHS = 5 
MODEL_NAME = "HELP" 

model = alexnet(WIDTH, HEIGHT, LR) 

for i in range(EPOCHS): 
    train_data = np.load("training.npy".format(i)) 

    train = train_data[:-100] 
    test = train_data[-100:] 

    X = np.array([i[0] for i in train]).reshape(-1,WIDTH,HEIGHT,1) 
    Y = [i[1] for i in train] 

    test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1) 
    test_y = [i[1] for i in test] 

    model.fit({'input': X}, {'targets': Y}, n_epoch=1, validation_set=({'input': test_x}, {'targets': test_y}), 
     snapshot_step=500, show_metric=True, run_id=MODEL_NAME) 

    model.save(MODEL_NAME) 

错误: 异常螺纹加工-3: 回溯(最近通话最后一个): 文件“C:\用户\ TF \应用程序数据\本地\程序\ Python \ Python35 \ lib \ threading.py“,第914行,在_bootstrap_inner中 self.run() 文件”C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ threading.py“,第862行,运行 self._target(* self._args,** self._kwargs) 文件“C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ site-packages \ tflearn \ data_flow的.py “行187,在fill_feed_dict_queue data = self.retrieve_data(batch_ids) 文件”C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ site-packages \ tflearn \ data_flow.py“,行222,in retrieve_data utils.slice_array(self.feed_dict [key],batch_ids) 文件“C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ site-packages \ tflearn \ utils.py” ,线187,在slice_array 返回X [开始]

IndexError:指数2936是出界对轴0的大小为1900

回答

0

罗伯特Kirchgessner博士: 您的输入数据集三个通道。

np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,3) 

在alexnet:

​​