2016-12-06 176 views
11

tensorflow MNIST tutorialmnist.train.next_batch(100)功能非常方便。我现在试图自己实现一个简单的分类。我有我的训练数据在一个numpy数组中。我怎么能为我自己的数据实现一个类似的功能来给我下一批?如何实现tensorflow的next_batch为自己的数据

sess = tf.InteractiveSession() 
tf.global_variables_initializer().run() 
Xtr, Ytr = loadData() 
for it in range(1000): 
    batch_x = Xtr.next_batch(100) 
    batch_y = Ytr.next_batch(100) 

回答

10

您发布的链接说:“我们得到了一个‘批’百个随机数据点,从我们的训练集”。在我的例子中,我使用了一个全局函数(不像你的例子中的方法),所以在语法上会有所不同。

在我的函数中,您需要传递想要的样本数和数据数组。

下面是正确的代码,从而确保样品有正确的标签:

import numpy as np 

def next_batch(num, data, labels): 
    ''' 
    Return a total of `num` random samples and labels. 
    ''' 
    idx = np.arange(0 , len(data)) 
    np.random.shuffle(idx) 
    idx = idx[:num] 
    data_shuffle = [data[ i] for i in idx] 
    labels_shuffle = [labels[ i] for i in idx] 

    return np.asarray(data_shuffle), np.asarray(labels_shuffle) 

Xtr, Ytr = np.arange(0, 10), np.arange(0, 100).reshape(10, 10) 
print(Xtr) 
print(Ytr) 

Xtr, Ytr = next_batch(5, Xtr, Ytr) 
print('\n5 random samples') 
print(Xtr) 
print(Ytr) 

及示范运行:

[0 1 2 3 4 5 6 7 8 9] 
[[ 0 1 2 3 4 5 6 7 8 9] 
[10 11 12 13 14 15 16 17 18 19] 
[20 21 22 23 24 25 26 27 28 29] 
[30 31 32 33 34 35 36 37 38 39] 
[40 41 42 43 44 45 46 47 48 49] 
[50 51 52 53 54 55 56 57 58 59] 
[60 61 62 63 64 65 66 67 68 69] 
[70 71 72 73 74 75 76 77 78 79] 
[80 81 82 83 84 85 86 87 88 89] 
[90 91 92 93 94 95 96 97 98 99]] 

5 random samples 
[9 1 5 6 7] 
[[90 91 92 93 94 95 96 97 98 99] 
[10 11 12 13 14 15 16 17 18 19] 
[50 51 52 53 54 55 56 57 58 59] 
[60 61 62 63 64 65 66 67 68 69] 
[70 71 72 73 74 75 76 77 78 79]] 
+2

我相信这不会按照用户的期望工作。输入Xtr和输出Ytr之间有1:1的相关性。随机化发生在每个人身上。相反,应该挑选一组随机值,​​然后应用于两组。 –

+1

谢谢,我更新了我的帖子。 – edo

+2

@edo您可以使用'data [idx]'来代替[id [i]中的数据[i]],这样您就不会再从榜单跳转到列表并再次回到ndarrays。 –

6

为了洗牌和采样每个小批量,状态是否在当前时代已经选择了一个样本也应该考虑。这是一个使用上述答案中的数据的实现。

import numpy as np 

class Dataset: 

def __init__(self,data): 
    self._index_in_epoch = 0 
    self._epochs_completed = 0 
    self._data = data 
    self._num_examples = data.shape[0] 
    pass 


@property 
def data(self): 
    return self._data 

def next_batch(self,batch_size,shuffle = True): 
    start = self._index_in_epoch 
    if start == 0 and self._epochs_completed == 0: 
     idx = np.arange(0, self._num_examples) # get all possible indexes 
     np.random.shuffle(idx) # shuffle indexe 
     self._data = self.data[idx] # get list of `num` random samples 

    # go to the next batch 
    if start + batch_size > self._num_examples: 
     self._epochs_completed += 1 
     rest_num_examples = self._num_examples - start 
     data_rest_part = self.data[start:self._num_examples] 
     idx0 = np.arange(0, self._num_examples) # get all possible indexes 
     np.random.shuffle(idx0) # shuffle indexes 
     self._data = self.data[idx0] # get list of `num` random samples 

     start = 0 
     self._index_in_epoch = batch_size - rest_num_examples #avoid the case where the #sample != integar times of batch_size 
     end = self._index_in_epoch 
     data_new_part = self._data[start:end] 
     return np.concatenate((data_rest_part, data_new_part), axis=0) 
    else: 
     self._index_in_epoch += batch_size 
     end = self._index_in_epoch 
     return self._data[start:end] 

dataset = Dataset(np.arange(0, 10)) 
for i in range(10): 
    print(dataset.next_batch(5)) 

输出为:

[2 8 6 3 4] 
[1 5 9 0 7] 
[1 7 3 0 8] 
[2 6 5 9 4] 
[1 0 4 8 3] 
[7 6 2 9 5] 
[9 5 4 6 2] 
[0 1 8 7 3] 
[9 7 8 1 6] 
[3 5 2 4 0] 

在第一和第二(第三和第四,...)小批量对应于一个整个时代..

1

我使用Anaconda和Jupyter 。 在Jupyter如果您运行?mnist你: File: c:\programdata\anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py Docstring: Datasets(train, validation, test)

在文件夹datesets你会发现mnist.py它包含了所有的方法,包括next_batch

1

如果你不希望得到的形状不匹配的错误在你tensorflow会话中运行 再使用的,而不是在第一个解决方案上面提供的功能,下面的函数(https://stackoverflow.com/a/40995666/7748451) -

def next_batch(num, data, labels): 

    ''' 
    Return a total of `num` random samples and labels. 
    ''' 
    idx = np.arange(0 , len(data)) 
    np.random.shuffle(idx) 
    idx = idx[:num] 
    data_shuffle = data[idx] 
    labels_shuffle = labels[idx] 
    labels_shuffle = np.asarray(labels_shuffle.values.reshape(len(labels_shuffle), 1)) 

    return data_shuffle, labels_shuffle 
相关问题