学习记录——pytorch里的gather方法

torch.gather(input, dim, index)

tt=torch.tensor([[1, 2], [3, 4]]),

index=torch.tensor([[0, 0], [1, 0]])

tt和index可写为

tensor([[1, 2],
        [3, 4]])

tensor([[0, 0],
        [1, 0]])

1.当dim=1时

torch.gather(tt, 1, index)

tensor([[1, 1],
        [4, 3]])

index有两个维度,每个维度里又有两个维度,所以,dim=1时,横向搜索tt这个tensor(至于为什么是横向,我也不知道)

在tt的第0个维度,有[1, 2],而index为[0, 0],故在[1, 2]中选择[1, 1]

在tt的第1个维度,有[3, 4],而index为[1, 0],故在[3, 4]中选择[4, 3]

所以输出为([[1, 1], [4, 3]])

 

2.dim=0时

t.gather(tt,0,index)

tensor([[1, 2],
        [3, 2]])

此时dim=0,故纵向搜索

tt的第0个维度为[1, 3], 而index为[0, 1],故选择[1, 3](时刻记住为纵向)

tt的第1个维度为[2, 4],而index为[0, 0], 故选择[2, 2] 

所以结果为([[1, 2], [3, 2]])

——————

可能不太好理解,用笔写写就知道大概意思了

这里维度一词可能用的不太准确,见谅

你可能感兴趣的:(torch)