2017-06-15 1036 views
0

我目前正在通过实施深度转网络来研究kaggle上的cats vs dogs分类任务。下面的代码行用于数据预处理:如何在python中为自定义数据实现next_batch()函数

def label_img(img): 
    word_label = img.split('.')[-3] 
    if word_label == 'cat': return [1,0] 
    elif word_label == 'dog': return [0,1] 

def create_train_data(): 
    training_data = [] 
    for img in tqdm(os.listdir(TRAIN_DIR)): 
     label = label_img(img) 
     path = os.path.join(TRAIN_DIR,img) 
     img = cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),IMG_SIZE,IMG_SIZE)) 
     training_data.append([np.array(img),np.array(label)]) 

    shuffle(training_data) 
    return training_data 

train_data = create_train_data() 

X_train = np.array([i[0] for i in train_data]).reshape(-1, IMG_SIZE,IMG_SIZE,1) 
Y_train =np.asarray([i[1] for i in train_data]) 

我想要实现复制在tensorflow深MNIST教程

batch = mnist.train.next_batch(100) 

回答

0

code提供了以下功能的功能就是一个很好的例子拿出生成批处理的功能。

简单说明,你只需要拿出两个数组的x_train和y_train喜欢:

batch_inputs = np.ndarray(shape=(batch_size), dtype=np.int32) 
    batch_labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 

并设置列车数据,如:

batch_inpouts[i] = ... 
    batch_labels[i, 0] = ... 

最后通过数据设置为会话:

_, loss_val = session.run([optimizer, loss], feed_dict={train_inputs: batch_inputs, train_labels:batch_labels}) 
+0

请试试看。谢谢你的时间。 –

2

除了生成一个批处理,您可能还想随机重新安排数据每批次。

EPOCH = 100 
BATCH_SIZE = 128 
TRAIN_DATASIZE,_,_,_ = X_train.shape 
PERIOD = TRAIN_DATASIZE/BATCH_SIZE #Number of iterations for each epoch 

for e in range(EPOCH): 
    idxs = numpy.random.permutation(TRAIN_DATASIZE) #shuffled ordering 
    X_random = X_train[idxs] 
    Y_random = Y_train[idxs] 
    for i in range(PERIOD): 
     batch_X = X_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE] 
     batch_Y = Y_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE] 
     sess.run(train,feed_dict = {X: batch_X, Y:batch_Y}) 
+0

非常感谢。最后,我可以正确地训练我的网络。 –

+0

你能否启发我tensorflow的next_batch()返回什么?它是指定批次大小的训练集中的随机数据集合吗?如果是这样,它确保不重复? @Joshua Lim –

+0

next_batch()是一个专门针对tensorflow提供的MNIST教程的函数。它的工作原理是在开始时随机化训练图像和标签对,并在每次调用函数时选择每个后续100张图像。一旦达到最后,图像标签对就会再次被随机化,并且重复该过程。整个数据集只有在使用所有可用对时才会重新洗牌并重复。 –

相关问题