深度学习笔记(三)——对Tensor索引操作函数gather的解释

索引操作函数gather函数详解

  • gather函数的输出规则
    • 第一条规则
    • 第二条规则
    • gather函数内部的代码机理推测
  • 代码示例
    • 输出结果
      • 参考文献

事先声明:本文只会对二维张量的gather操作进行介绍,三维张量的gather操作规则在csdn上的博文屡见不鲜。本文的解释是从个人的理解出发,相信解释也会对理解三维张量的操作规则起到触类旁通的作用。

gather函数的输出规则

o u t [ i ] [ j ] = i n p u t [ i n d e x [ i ] [ j ] ] [ j ] , i f   d i m = = 0 out [i] [j] = input [index [i] [j] ] [j], if{\ }dim == 0 out[i][j]=input[index[i][j]][j],if dim==0
o u t [ i ] [ j ] = i n p u t [ i ] [ i n d e x [ i ] [ j ] ] , i f   d i m = = 1 out [i] [j] = input [i] [index [i] [j] ], if{\ }dim == 1 out[i][j]=input[i][index[i][j]],if dim==1

第一条规则

从行的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个 2 ∗ 3 2*3 23的维度的张量,index为 [ 0 , 1 , 1 ] [0, 1, 1] [0,1,1],取 d i m = 0 dim=0 dim=0,根据规则,外层循环为变量 i i i,内层循环为变量 j j j,且 i   i n   r a n g e ( 0 , 2 ) ; j   i n   r a n g e ( 0 , 3 ) i {\ }in{\ } range(0, 2); j{\ } in{\ } range(0, 3) i in range(0,2);j in range(0,3)
代入 i = 0 ,   j = 1 i=0,{\ }j=1 i=0, j=1,得到:
o u t [ 0 ] [ 1 ] = i n p u t [ i n d e x [ 0 ] [ 1 ] ] [ 1 ] out[0][1]=input[index[0][1]][1] out[0][1]=input[index[0][1]][1]
o u t [ 0 ] [ 1 ] = i n p u t [ 1 ] [ 1 ] out[0][1]=input[1][1] out[0][1]=input[1][1]
即:该输出元素为输入的 2 ∗ 3 2*3 23维度张量的第1行第1列元素。且该元素在输出张量中处在第0行第1列的位置。
如下表所示:

0 1 2
0
1 this element

其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。

第二条规则

从列的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个 2 ∗ 3 2*3 23的维度的张量,index为 [ [ 0 , 1 , 1 ] , [ 1 , 1 , 1 ] ] [[0, 1, 1],[1, 1, 1]] [[0,1,1],[1,1,1]],取 d i m = 1 dim=1 dim=1,根据规则,外层循环为变量 i i i,内层循环为变量 j j j,且 i   i n   r a n g e ( 0 , 2 ) ; j   i n   r a n g e ( 0 , 3 ) i {\ }in{\ } range(0, 2); j{\ } in{\ } range(0, 3) i in range(0,2);j in range(0,3)
代入 i = 1 ,   j = 1 i=1,{\ }j=1 i=1, j=1,得到:
o u t [ 1 ] [ 1 ] = i n p u t [ 1 ] [ i n d e x [ 1 ] [ 1 ] ] out[1][1]=input[1][index[1][1]] out[1][1]=input[1][index[1][1]]
o u t [ 1 ] [ 1 ] = i n p u t [ 1 ] [ 1 ] out[1][1]=input[1][1] out[1][1]=input[1][1]
即:该输出元素为输入的 2 ∗ 3 2*3 23维度张量的第1行第1列元素,且该元素在输出张量中处在第1行第1列的位置。
如下表所示:

0 1 2
0
1 this element

其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。

gather函数内部的代码机理推测

声明:下述代码仅针对原理部分编写,距离函数内部真实情况仍存在较大差距,且下述代码的严谨性不够,故仅供理解gather的核心规则。

def gather(input, dim, index):
# 这里的dim要求取0或1
	out = []
	m = input.size()[0] # size函数是torch的方法
	n = input.size()[1]
	for i in range(m):
		for j in range(n):
			if dim == 0:
				out [i] [j] = input [index [i] [j] ] [j]
			if dim == 1:
				out [i] [j] = input [i] [index [i] [j] ]
	return out

代码示例

与上一篇博文内容相同,这里再次展示一遍。

import torch
# 设置一个随机种子
torch.manual_seed(100)
# 生成一个形状为2*3的矩阵
x = torch.randn(2, 3)
print(x)
# 获取指定索引对应的值
index = torch.LongTensor([[0, 1, 1]])
print(torch.gather(x, 0, index))
index = torch.LongTensor([[0, 1, 1], [1, 1, 1]])
a = torch.gather(x, 1, index)
print(a)

输出结果

深度学习笔记(三)——对Tensor索引操作函数gather的解释_第1张图片

参考文献

吴茂贵,郁明敏,杨本法,李涛,张粤磊. Python深度学习(基于Pytorch). 北京:机械工业出版社,2019.

你可能感兴趣的:(深度学习(基于pytorch),深度学习,pytorch)