我有标签的NumPy的数组:IndexError索引的二维数组与一维数组(NumPy的)
labels = np.ndarray(10000, dtype=np.float32)
在数组中的元素看起来像:
print(labels[1:5])
Output: [ 9. 9. 4. 1.]
我想将它们转换成一个热编码的标签,我用下面的代码:
one_hot_labels = np.eye(10)[labels]
我得到以下错误:
IndexError Traceback (most recent call last)
<ipython-input-21-dccf85afc031> in <module>()
1
----> 2 s=np.eye(10)[labels]
IndexError: arrays used as indices must be of integer (or boolean) type
我该如何解决这个问题?
你确定标签和火车标签是一样的吗? –
你需要使用整数值作为索引:'one_hot_labels = np.eye(10)[labels.astype(int)]' – JohanL
@JohanL谢谢。它的工作原理 – Jayanth