选择函数 torch.gather()的理解

【时间】2019.03.19

【题目】选择函数 torch.gather()的理解

1、Pytorch中的torch.gather函数的含义
2、pytorch之torch.gather方法

torch.gather(input, dim, index, out=None) → Tensor

【注意】

返回的tensor的size与index的size一致。

dim用于指明index的元素值代表的维数。这个函数可以用来很方便地提取方阵

的对角元素。比如:

import torch as t
a = t.arange(0, 16).view(4, 4)

index = t.LongTensor([[0,1,2,3]])

b = t.gather(a,0, index)
print(a)
print(index)
print(b)

【运行结果】:表示从a中提取shape为index的shape(即1X2)的tensor,dim = 0指明index元素值的维度,所以所选择的完整索引值为[[(0,0),(1,1),(2,2),(3,3)]],即最终结果b取a的对角线。

选择函数 torch.gather()的理解_第1张图片
官方给出的解释是这样的:
                    沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
                    对一个3维张量,输出可以定义为:

                out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0

                out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1

                out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3

你可能感兴趣的:(pytorch,torch.gather)