pytorch中的gather函数_Pytorch中的torch.gather函数的理解

Pytorch中的torch.gather函数的理解

Pytorch中的torch.gather函数

pytorch比tensorflow更加编程友好,准备用pytorch试着做一些实验。

先看一下简单的用法示例代码,然后结合官方示例来解读:

b = torch.Tensor([[1,2,3],[4,5,6]])

print b

index_0 = torch.LongTensor([[1],[2]])

index_1 = torch.LongTensor([[0,1],[2,0]])

index_2 = torch.LongTensor([[0,1,1],[0,0,0]])

print (torch.gather(b, dim=1, index=index_0))

print (torch.gather(b, dim=1, index=index_1))

print (torch.gather(b, dim=0, index=index_2))

输出结果:

1 2 3

4 5 6

[torch.FloatTensor of size 2x3]

tensor([[2.],

[6.]])

1 2

6 4

[torch.FloatTensor of size 2x2]

1 5 6

1 2 3

[torch.FloatTensor of size 2x3]

结合上面的例子来看官方解读及示例,官方解读是给了三个公式:

torch.gather(input, dim, index, out=None) → Tensor

'''

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # dim=0

out[i][j][k] = input[i][index[i][j][k]][k] # dim=1

out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

Parameters:

input (Tensor) – The source tensor

dim (int) – The axis along which to index

index (LongTensor) – The indices of elements to gather

out (Tensor, optional) – Destination tensor

Example:

'''

>>> t = torch.Tensor([[1,2],[3,4]])

>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))

1 1

4 3

[torch.FloatTensor of size 2x2]

首先,可以看出output的形状和index的一致,且位置一 一对应。

dim=0时,

out[i][j][k] = input[index[i][j][k]][j][k]

out的取值为input[index[i][j][k]] [j] [k],为input值,output行(dim=0)的取值是index张量的元素值,列(dim=1)和index张量里面的列对应,(dim=2)维度也和index的一致。

例1:

b = torch.Tensor([[1,2,3],[4,5,6]])

index_2 = torch.LongTensor([[0,1,1],[0,0,0]])

print (torch.gather(b, dim=0, index=index_2))

'''

out[i][j][k] = input[index[i][j][k]][j][k] # dim=0

输出结果:

1 5 6

1 2 3

[torch.FloatTensor of size 2x3]

'''

一二三四五六

dim=0, index_2为两行三列的张量,output也为两行三列的张量。index中的【0,1,1】在第一行,【0,0,0】在第二行。

取值时,out第一个值取input第一行(由由【0,1,1】的0指定)第一列(这个是与index对应,0在第一列)的元素1;

第二个数取input第二行(由【0,1,1】的第一个1指定)第二列(这个是与index对应,第一个1在第二列)的元素5;

第三个数取input第二行(由【0,1,1】的第二个1指定)第三列(这个是与index对应,第二个1在第三列)的元素6。

第四个数取input第一行(由【0,0,0】的第一个0指定)第一列(这个是与index对应,第一个0在第一列)的元素1。

后面的同理。

dim=1时,

out[i][j][k] = input[i] [index[i][j][k]] [k] # dim=1

out的取值为input[i] [index[i][j][k]] [k],为input值,行(dim=0)和index张量里面的行对应,列(dim=1)是index张量的元素值,(dim=2)维度也和index的(dim=2)维度 对应。

例2:

b = torch.Tensor([[1,2,3],[4,5,6]])

index_1 = torch.LongTensor([[1,2],[2,0]])

print (torch.gather(b, dim=1, index=index_1))

'''

out[i][j][k] = input[i] [index[i][j][k]] [k] # dim=1

输出结果:

2 3

6 4

[torch.FloatTensor of size 2x2]

'''

一二三四

dim=1, index_1为两行两列的张量,output也为两行两列的张量。index中的【1,2】在第一行,【2,0】在第二行。

第一个数取input第一行(这个是与index对应,同在第一行)第二列(由【1,2】中的1指定 )的元素2,

第二个数取input第一行(这个是与index对应,同在第一行)第三列(由【1,2】中的2指定)的元素3。

第三个数取input第二行(这个是与index对应,同在第二行)第三列(由【2,0】中的2指定)的元素6。

第四个数取input第二行(这个是与index对应,同在第二行)第一列(由【2,0】中的0指定)的元素4。

还可以看出index的形状和input的形状是一致的,都是二维的,里面的index数值不能超过input的界限,比如行的不能超过1,列的不能超过2。

理解了这几个式子也就记住了这个方法的用法。

你可能感兴趣的:(pytorch中的gather函数_Pytorch中的torch.gather函数的理解)