【函数小Trick】torch.gather(获取高维数据/矩阵/数组特定位置值)

1.背景

可以先看torch官方文档介绍

【函数小Trick】torch.gather(获取高维数据/矩阵/数组特定位置值)_第1张图片

主要作用是根据索引值index,找出input向量中指定dim维度所对应的数值,熟练使用该函数,就不用暴力for循环啦。

2.函数应用场景

(1)自然语言处理中的mask与padding位置

import torch
a= torch.Tensor([
    [4,1,2,0,0],
    [2,4,0,0,0],
    [1,1,1,6,5],
    [1,2,2,2,2],
    [3,0,0,0,0],
    [2,2,0,0,0]])
index = torch.LongTensor([[3],[2],[5],[5],[1],[2]])
print(a.size(),index.size())
b = torch.gather(a, 1,index-1)
print(b)

注意维度的选择:

【函数小Trick】torch.gather(获取高维数据/矩阵/数组特定位置值)_第2张图片

(2)神经网络特定类的预测值

import torch
a= torch.Tensor([
    [0.4,0.1,0.2,0.,0.3],
    [0.2,0.4,0.2,0.1,0.1],
    [0.1,0.1,0.1,0.6,0],
    [0.1,0.3,0.2,0.2,0.2],
    [0.3,0.0,0.7,0,0],
    [0.2,0.2,0.0,0.5,0.1]])
index = torch.LongTensor([[1],[2],[4],[2],[3],[4]])
print(a.size(),index.size())
b = torch.gather(a, 1,index-1)
print(b)

输出:

【函数小Trick】torch.gather(获取高维数据/矩阵/数组特定位置值)_第3张图片

你可能感兴趣的:(Python#函数小trick,python,算法,人工智能,gather,特定位置输出)