2015-08-14 75 views
9

我有一个整数类标签的字节张量,例如来自MNIST数据集。在火炬中如何从整数标签列表创建单热张量?

1 
7 
5 
[torch.ByteTensor of size 3] 

如何使用它来创建1热矢量的张量?

1 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 1 0 0 0 
0 0 0 0 1 0 0 0 0 0 
[torch.DoubleTensor of size 3x10] 

我知道我可以用一个循环做到这一点,但我不知道是否有任何聪明的火炬索引,将让这对我在一个单一的线。

回答

13
indices = torch.LongTensor{1,7,5}:view(-1,1) 
one_hot = torch.zeros(3, 10) 
one_hot:scatter(2, indices, 1) 

你可以找到在torch/torch7 github readmescatter的文件(在主分支)。

2

的另一种方法是从单位矩阵洗牌行:

indicies = torch.LongTensor{1,7,5} 
one_hot = torch.eye(10):index(1, indicies) 

这不是我的主意,我发现它在karpathy/char-rnn