torch gather函数理解 图解

看了好几篇了,没有直接看明白,特梳理之

功能

数据收集,函数torch.gather(input, dim, index, out=None) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.

理解

对于一个shape为(3,4)的数据a,可用索引a[1,2]取 a[1][2]数据;即通过a[i,j]的方式可以获取数据,gather则通过类似方式收集数据。

通过示例理解gather具体的方式,gather则内容更丰富。

结论

Gather获取的数据shape和idx索引的shape一样,获取数据的内容为data中idx对应的坐标,idx对应的坐标按指定轴被idx坐标处值替换后的坐标。

idx数组的shape不一定小于data的,只要对应的坐标构造后合法即可。

看后面示例

1维数据

dim 只能取 0;

>>> a=torch.arange(1,12,2) 
>>> a
tensor([ 1,  3,  5,  7,  9, 11])
>>> idx = torch.tensor([1,3,5]) 
>>> a.gather(0,idx)
tensor([ 3,  7, 11])

取第dim=0,的第1,3,5个数据,可以读取任意位置 任意数量的元素。

torch gather函数理解 图解_第1张图片

2维数据

dim能取0,1;

0轴取数据示例,通过构造的坐标理解gather 的原理

torch gather函数理解 图解_第2张图片

 1轴取数据示例

torch gather函数理解 图解_第3张图片

 3维数据

3维数据0轴

torch gather函数理解 图解_第4张图片

 3维数据1轴torch gather函数理解 图解_第5张图片

 3维数据2轴

torch gather函数理解 图解_第6张图片

应用: 我从RL中看到这个函数,其他的应用大家补充⑧

你可能感兴趣的:(pytorch,gather)