2017-10-18 41 views
0

我试图在pytorch中使用gather函数,但无法理解dim参数的作用。收集函数中参数尺寸的影响

代码:

t = torch.Tensor([[1,2],[3,4]]) 
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]]))) 

输出:

1 2 
3 2 
[torch.FloatTensor of size 2x2] 

尺寸设置为1:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))) 

输出变为:

1 1 
4 3 
[torch.FloatTensor of size 2x2] 

如何,gather函数实际工作?

回答

2

我意识到收集功能是如何工作的。

​​

由于dimension为零,因此,输出将是:

| t[index[0, 0] 0] t[index[0, 1] 1] | 
| t[index[1, 0] 0] t[index[1, 1] 1] | 

如果dimension被设置为一,则输出将变为:

| t[0 index[0, 0]] t[0 index[0, 1]] | 
| t[1 index[1, 0]] t[1 index[1, 1]] | 

所以公式是:

For a 3-D tensor the output is specified by: 

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 

参考:http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather