torch.gather函数的简单理解与使用

功能:根据索引来对高维tensor进行选择
要求:

  • input tensor 与 index 的 dim一致
  • index.shape < input.shape
torch.gather(input, dim, index) → Tensor

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
import torch

# [[1,2,3],
#  [4,5,6],
#  [7,8,9]]
input = torch.range(1, 9).view(3, 3)

示例1

#====================== dim=0 索引======================#
# 表示提供采样索引维度的是dim=0
dim = 0		
# index的维度为 [1, 3], 说明输出维度也为 [1, 3], index每个元素的索引为 [[(0, 0), (0, 1), (0, 2)]]
index = torch.tensor([[2, 1, 0]])
# 用index 来替换index的索引列表中的dim=0 得到: [[(2, 0), (1, 1), (0, 2)]]
output = torch.gather(input, dim, index)
# 将input索引为 (2, 0), (1, 1), (0, 3) 取出来就是 [[7, 5, 3]] 	

torch.gather函数的简单理解与使用_第1张图片


示例2

#======================== dim=1 索引=====================#
# 表示提供的采样索引维度dim=1
dim = 1
# index的维度为 (1, 3), 也就是index每个元素的索引为    [[(0, 0), (0, 1), (0, 2)]], 
index = torch.tensor([[2, 1, 0]])
# 用index的取值来替代 index的索引列表中的dim=1的元素得:[[(0,2) (0,1) (0,0)]]
# 将input索引为 (0,2) (0,1) (0,0) 取出来就是[[3, 2, 1]]
optput = torch.gather(input, dim, index)

torch.gather函数的简单理解与使用_第2张图片


示例3

#=============================================#
dim = 1	# 表示采样索引为1
index = torch.tensor([[2],
		 			  [1],
		 			  [0]])
# index的索引为[(0, 0),
			   (1, 0),
			   (2, 0)]
# 使用index的其余维索引来补全后得到:
'''
 [(0, 2),
  (1, 1),
  (2, 0)]
'''
# 对input索引
output = torch.gather(input, dim, index)
'''
[[3],
 [5],
 [7]]
'''

torch.gather函数的简单理解与使用_第3张图片


示例4

#====================== 多维index =====================#
dim = 1
index = torch.tensor([[0, 2], 
                      [1, 2]])
# index 的索引为 [[(0,0), (0,1)], 
# 				 [(1,0), (1,1)]]
# 用除了1维以外的索引将index补全得:
# [[(0,0), (0,2)], 
#  [(1,1), (1,2)]]

# 对input索引
output = torch.gather(input, dim, index)		  
#[[0, 3]
# [5, 6]]

torch.gather函数的简单理解与使用_第4张图片

更高维度的gather索引也是如此,先生成index每个元素的索引,再用index的值来替代dim维度的索引值,最后按照索引值到input中索引得到output

你可能感兴趣的:(pytorch踩坑日记,python,深度学习,人工智能)