2017-03-10 63 views
0

的另一阵列的numpy的阵列,即:过滤使用给定两个numpy的阵列标签

images.shape: (60000, 784) # An array containing 60000 images 
labels.shape: (60000, 10) # An array of labels for each image 

labels每行包含一个1在特定索引以指示images类相关的例子。 (所以[0 0 1 0 0 0 0 0 0 0]将表明这个例子属于第2类(假设我们的类索引从0开始)

我试图有效地分离images,这样我就可以一次处理属于特定类的所有图像。显而易见的解决方案是使用一个for环(如下图),但我不知道如何过滤images使得仅返回那些具有相应labels

for i in range(0, labels.shape[1]): 
    class_images = # (?) Array containing all images that belong to class i 

顺便说一句,我还想知道是否有更有效的方法可以消除使用for循环。

回答

1

一种方法是你的标签阵列为bool,并用它来索引转换:

classes = [] 
blabels = labels.astype(bool) 
for i in range(10): 
    classes.append(images[blabels[:, i], :]) 

或者作为一个班轮使用列表理解:

classes = [images[l.astype(bool), :] for l in labels.T] 
0
_classes= [[] for x in range(10)] 
for image_index , element in enumerate(labels): 
    _classes[element.index(1)].append(image_index) 

例如_classes [0]将包含分类为class0的图像索引。

+0

如果你正在使用你numpy的可以使用 非零(元素== 1)[0] [0] 而不是element.index(1) – Pouyan