torch.gather()使用解析

前言

这个本来是没打算写的,因为看了官方的解释以及在网上看了好几个教程都没理解什么意思,所以把自己理解的东西整理分享一下。

官方的解释

官网链接:torch.gather()
给个截图如下
torch.gather()使用解析_第1张图片
常用的参数有3个,第一个input表示要从中选取元素,第二个dim表示操作的维度,第三个index表示选取元素的索引。
按照官方的解释我是没看懂的,后面去找教程也一知半解,所以自己琢磨了一下,终于悟了。

使用详解

结合着例子,直接看代码把:

import torch

a = torch.arange(3, 12).view(3, 3)
print(a)
# tensor([[ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
# tensor([[9, 7, 5]])
# 1、将index中的各个元素的索引明确,获得具体坐标:
#    index = torch.tensor([[2, 1, 0]])中,
#    2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)
# 2、将具体坐标中对应的维度替换成index中的值:
#    2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
#    1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
#    0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0
# 3、按照新的坐标取输入中的值:
# tensor([[ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].

index = torch.tensor([[2, 1, 0]])
c = torch.gather(a, dim=1, index=index)
print(c)    # tensor([[5, 4, 3]])
# 1、获取具体坐标:(0,0),(0,1),(0,2)
# 2、第1维度替换坐标:(0,2),(0,1),(0,0)
# 3、找元素:[5,4,3]

# 二维的情况也一样
index = torch.tensor([[0, 2],
                      [1, 2]])
d = torch.gather(a, dim=1, index=index)
print(d)
# tensor([[3, 5],
#         [7, 8]])
# 1、获取具体坐标:(0,0),(0,1),(1,0),(1,1)
# 2、第1维度替换坐标:(0,0),(0,2),(1,1),(1,2)
# 3、找元素:[[3, 5],[7, 8]]

怕在代码里面太暗了看不清楚,在这里再贴一次:
以第一个为例:

创建张量
a = torch.arange(3, 12).view(3, 3)
print(a)
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
a 的值如下:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
b 的值为tensor([[9, 7, 5]])
具体过程:

1、将index中的各个元素的索引明确,获得具体坐标:
index = torch.tensor([[2, 1, 0]])中,
2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)

2、将具体坐标中对应的维度替换成index中的值:
2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0

3、按照新的坐标取输入中的值:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].

实战操作

torch.gather()这个函数通常用在批量的获取张量的某些数据,比如说,要获取一个大小为(b, n, c)的张量中的多个不连续的索引行向量,这种操作通常在下采样的过程中会用到。按照正常的思路,实现这样的操作需要写几个for循环,但for循环在训练时特别的慢,因此可以使用torch.gather来实现这一功能。
具体例子:

def test_gather():
    b, n, c = 4, 3, 3
    k = 2   # 下采样的个数
    a = torch.rand((b, n, c))   # 定义输入数据
    print('a:', a)
    idx = torch.randint(low=0, high=n, size=(b, k))     # (b, k),生成随机索引
    print('idx:', idx)
    # 进行维度扩展和复制
    new_idx = idx.unsqueeze(-1).expand(-1, -1, 3)
    print('idx after expand', new_idx)
    # 关键语句,按照一定的维度来取出数据
    b = torch.gather(a, dim=1, index=new_idx)
    print('b', b)
    # 这个是for循环版本的操作
    c = torch.stack(
        [a[i][idx[i], :] for i in range(len(a))]
    )

    print(b == c)

打印信息如下:

a: tensor([[[0.8053, 0.7751, 0.7346],
         [0.4371, 0.1006, 0.6389],
         [0.9040, 0.1699, 0.3022]],

        [[0.7410, 0.5656, 0.9189],
         [0.4067, 0.4953, 0.1776],
         [0.9622, 0.0738, 0.3553]],

        [[0.5321, 0.9538, 0.5806],
         [0.2257, 0.7163, 0.7548],
         [0.2393, 0.4100, 0.2497]],

        [[0.2234, 0.9685, 0.7388],
         [0.7087, 0.0933, 0.7147],
         [0.1741, 0.0103, 0.6587]]])
idx: tensor([[0, 2],
        [2, 2],
        [0, 0],
        [0, 1]])
idx after expand tensor([[[0, 0, 0],
         [2, 2, 2]],

        [[2, 2, 2],
         [2, 2, 2]],

        [[0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 1, 1]]])
b tensor([[[0.8053, 0.7751, 0.7346],
         [0.9040, 0.1699, 0.3022]],

        [[0.9622, 0.0738, 0.3553],
         [0.9622, 0.0738, 0.3553]],

        [[0.5321, 0.9538, 0.5806],
         [0.5321, 0.9538, 0.5806]],

        [[0.2234, 0.9685, 0.7388],
         [0.7087, 0.0933, 0.7147]]])
tensor([[[True, True, True],
         [True, True, True]],

        [[True, True, True],
         [True, True, True]],

        [[True, True, True],
         [True, True, True]],

        [[True, True, True],
         [True, True, True]]])

Process finished with exit code 0

再来详细的看一下是怎么选取数据的:
以b中的第一个行向量为例子,看一下是怎么得到的
1、明确坐标,idx[0][0] = 0,也就是说1的坐标为(0, 0),由于这里0代表的是第一行向量,但是在torch.gather中要精确到具体的坐标,0只是代表了一个维度而且,还有两个维度,因此要将其进行维度扩展和复制:

idx.unsqueeze(-1).expand(-1, -1, 3)	# unsqueeze函数进行维度扩张,expand对最后一个维度进行复制

具体效果可以从打印的信息看出:

维度扩展之前:idx: tensor([[0, 2],
维度扩展之前:idx after expand tensor([[[0, 0, 0],
         [2, 2, 2]],

2、替换坐标,因此索引0的值在扩张之后是(0, 0, 0),索引2为(2, 2, 2)由于在本次的代码中torch.gather是对第一维度进行操作 b = torch.gather(a, dim=1, index=new_idx),因此对第一个维度进行变换,因此对第一个维度进行替换,取值过程如下,(0, 0, 0)的对应坐标为(0, 0, 0), (0, 0, 1), (0, 0, 2),对第一个维度进行替换得到new_idx = (0, 0, 0), (0, 0, 1), (0, 0, 2),因此就可以根据这new_idx去a中取值,也就是a[0][0][:]。同样的道理,(2, 2, 2)的坐标为(0, 1, 0), (0, 1, 1),(0, 1, 2) -> (0, 2, 0), (0, 2, 1),(0, 2, 2),因此根据这三个索引可以取出a中的值,输出结果可以从打印信息看出:

a[0][0] = [0.8053, 0.7751, 0.7346]
a[0][2] = [0.9040, 0.1699, 0.3022]
对应的取出的值为:
b tensor([[[0.8053, 0.7751, 0.7346],
         [0.9040, 0.1699, 0.3022]],

参考链接

图解PyTorch中的torch.gather函数
Pytorch系列(1):torch.gather()
pytorch之torch.gather方法
pytorch中的所有随机数(normal、rand、randn、randint、randperm) 以及 随机数种子(seed、manual_seed、initial_seed)

结束语

文章为分享、记录、整理自己的经历情况,水平有限,如有错误之处敬请指出。

你可能感兴趣的:(实用技巧,深度学习,pytorch,人工智能)