o u t [ i ] [ j ] = i n p u t [ i n d e x [ i ] [ j ] ] [ j ] , i f d i m = = 0 out [i] [j] = input [index [i] [j] ] [j], if{\ }dim == 0 out[i][j]=input[index[i][j]][j],if dim==0
o u t [ i ] [ j ] = i n p u t [ i ] [ i n d e x [ i ] [ j ] ] , i f d i m = = 1 out [i] [j] = input [i] [index [i] [j] ], if{\ }dim == 1 out[i][j]=input[i][index[i][j]],if dim==1
从行的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个 2 ∗ 3 2*3 2∗3的维度的张量,index为 [ 0 , 1 , 1 ] [0, 1, 1] [0,1,1],取 d i m = 0 dim=0 dim=0,根据规则,外层循环为变量 i i i,内层循环为变量 j j j,且 i i n r a n g e ( 0 , 2 ) ; j i n r a n g e ( 0 , 3 ) i {\ }in{\ } range(0, 2); j{\ } in{\ } range(0, 3) i in range(0,2);j in range(0,3)。
代入 i = 0 , j = 1 i=0,{\ }j=1 i=0, j=1,得到:
o u t [ 0 ] [ 1 ] = i n p u t [ i n d e x [ 0 ] [ 1 ] ] [ 1 ] out[0][1]=input[index[0][1]][1] out[0][1]=input[index[0][1]][1]
o u t [ 0 ] [ 1 ] = i n p u t [ 1 ] [ 1 ] out[0][1]=input[1][1] out[0][1]=input[1][1]
即:该输出元素为输入的 2 ∗ 3 2*3 2∗3维度张量的第1行第1列元素。且该元素在输出张量中处在第0行第1列的位置。
如下表所示:
0 | 1 | 2 | |
---|---|---|---|
0 | |||
1 | this element |
其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。
从列的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个 2 ∗ 3 2*3 2∗3的维度的张量,index为 [ [ 0 , 1 , 1 ] , [ 1 , 1 , 1 ] ] [[0, 1, 1],[1, 1, 1]] [[0,1,1],[1,1,1]],取 d i m = 1 dim=1 dim=1,根据规则,外层循环为变量 i i i,内层循环为变量 j j j,且 i i n r a n g e ( 0 , 2 ) ; j i n r a n g e ( 0 , 3 ) i {\ }in{\ } range(0, 2); j{\ } in{\ } range(0, 3) i in range(0,2);j in range(0,3)。
代入 i = 1 , j = 1 i=1,{\ }j=1 i=1, j=1,得到:
o u t [ 1 ] [ 1 ] = i n p u t [ 1 ] [ i n d e x [ 1 ] [ 1 ] ] out[1][1]=input[1][index[1][1]] out[1][1]=input[1][index[1][1]]
o u t [ 1 ] [ 1 ] = i n p u t [ 1 ] [ 1 ] out[1][1]=input[1][1] out[1][1]=input[1][1]
即:该输出元素为输入的 2 ∗ 3 2*3 2∗3维度张量的第1行第1列元素,且该元素在输出张量中处在第1行第1列的位置。
如下表所示:
0 | 1 | 2 | |
---|---|---|---|
0 | |||
1 | this element |
其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。
声明:下述代码仅针对原理部分编写,距离函数内部真实情况仍存在较大差距,且下述代码的严谨性不够,故仅供理解gather的核心规则。
def gather(input, dim, index):
# 这里的dim要求取0或1
out = []
m = input.size()[0] # size函数是torch的方法
n = input.size()[1]
for i in range(m):
for j in range(n):
if dim == 0:
out [i] [j] = input [index [i] [j] ] [j]
if dim == 1:
out [i] [j] = input [i] [index [i] [j] ]
return out
与上一篇博文内容相同,这里再次展示一遍。
import torch
# 设置一个随机种子
torch.manual_seed(100)
# 生成一个形状为2*3的矩阵
x = torch.randn(2, 3)
print(x)
# 获取指定索引对应的值
index = torch.LongTensor([[0, 1, 1]])
print(torch.gather(x, 0, index))
index = torch.LongTensor([[0, 1, 1], [1, 1, 1]])
a = torch.gather(x, 1, index)
print(a)
吴茂贵,郁明敏,杨本法,李涛,张粤磊. Python深度学习(基于Pytorch). 北京:机械工业出版社,2019.