0
我试图在pytorch中使用gather函数,但无法理解dim
参数的作用。收集函数中参数尺寸的影响
代码:
t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
输出:
1 2
3 2
[torch.FloatTensor of size 2x2]
尺寸设置为1:
print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
输出变为:
1 1
4 3
[torch.FloatTensor of size 2x2]
如何,gather
函数实际工作?