torch.gather(input, dim, index)
设
tt=torch.tensor([[1, 2], [3, 4]]),
index=torch.tensor([[0, 0], [1, 0]])
tt和index可写为
tensor([[1, 2],
[3, 4]])
tensor([[0, 0],
[1, 0]])
1.当dim=1时
torch.gather(tt, 1, index)
tensor([[1, 1],
[4, 3]])
index有两个维度,每个维度里又有两个维度,所以,dim=1时,横向搜索tt这个tensor(至于为什么是横向,我也不知道)
在tt的第0个维度,有[1, 2],而index为[0, 0],故在[1, 2]中选择[1, 1]
在tt的第1个维度,有[3, 4],而index为[1, 0],故在[3, 4]中选择[4, 3]
所以输出为([[1, 1], [4, 3]])
2.dim=0时
t.gather(tt,0,index)
tensor([[1, 2],
[3, 2]])
此时dim=0,故纵向搜索
tt的第0个维度为[1, 3], 而index为[0, 1],故选择[1, 3](时刻记住为纵向)
tt的第1个维度为[2, 4],而index为[0, 0], 故选择[2, 2]
所以结果为([[1, 2], [3, 2]])
——————
可能不太好理解,用笔写写就知道大概意思了
这里维度一词可能用的不太准确,见谅