2017-07-19 145 views
0

我有标签的NumPy的数组:IndexError索引的二维数组与一维数组(NumPy的)

labels = np.ndarray(10000, dtype=np.float32) 

在数组中的元素看起来像:

print(labels[1:5]) 
Output: [ 9. 9. 4. 1.] 

我想将它们转换成一个热编码的标签,我用下面的代码:

one_hot_labels = np.eye(10)[labels] 

我得到以下错误:

IndexError  Traceback (most recent call last) 
<ipython-input-21-dccf85afc031> in <module>() 
    1 
----> 2 s=np.eye(10)[labels] 

IndexError: arrays used as indices must be of integer (or boolean) type 

我该如何解决这个问题?

+0

你确定标签和火车标签是一样的吗? –

+2

你需要使用整数值作为索引:'one_hot_labels = np.eye(10)[labels.astype(int)]' – JohanL

+0

@JohanL谢谢。它的工作原理 – Jayanth

回答

2

您已将标签定义为np.float32。如果要将它们用作数组或矩阵的索引,则它们必须是整数。要转换np.float32使用.astype(int)

one_hot_labels=np.eye(10)[labels.astype(int)] 

或整数直接定义标签:

labels=np.ndarray(10000,dtype=int) 
+1

@Jayanth如果他回答了你的问题,请接受答案。 :) – SH7890

1

如果labelsfloat,你不希望改变其dtype,你可以简单地使用MultiLabelBinarizer。这段代码应该完成这项工作:

from sklearn.preprocessing import MultiLabelBinarizer 

mlb = MultiLabelBinarizer() 
one_hot_labels = mlb.fit_transform(labels[:, None])