[pytorch]gather、stack函数

gather函数:gather API

这个函数在MAE模型中的random-masking函数中也有应用。
gather函数可以理解为根据对应的索引从原始的tensor中选择tensor,首先来看2-D的情况:

t = torch.tensor([[1,2],[3,4]])
torch.gather(t,1,torch.tensor([[0,0],[1,0]])) 
#tensor([[1, 1],
#        [4, 3]])

dim对应变化的索引位置,在上面的例子中dim=1,这就意味着是从t张量的维度1去选择的,从t[dim0][dim1]。选择的索引对应[0,0]和[1,0]这就是说是dim0=0的时候选择t张量中dim0=0 dim1为给出的索引的元素,本例中给出的索引为[0,0],也就是说第一行的输出为t[0][0]t[0][0]即都是元素1,同理在dim=1的时候就是t[1][1]t[1][0],也就是4,3。

再看个3D的例子:

t = torch.range(0,23).reshape(2,3,4)

'''
t:tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
 '''
select = torch.tensor([[[2,1,0,0,],[2,0,0,0],[1,1,2,2]],[[2,1,1,2],[2,0,0,0],[1,1,2,2]]])

# select.shape # the same as t
torch.gather(t,dim=1,index=select) 
'''
tensor([[[ 8.,  5.,  2.,  3.],
         [ 8.,  1.,  2.,  3.],
         [ 4.,  5., 10., 11.]],

        [[20., 17., 18., 23.],
         [20., 13., 14., 15.],
         [16., 17., 22., 23.]]])
'''
        

从上面的例子可以看出dim=1,所以对dim=1的元素进行选择,例如2,1,0,0对应的就是t张量在dim=1的时候

[[ 0.,  1.,  2.,  3.],
 [ 4.,  5.,  6.,  7.],
 [ 8.,  9., 10., 11.]],

分别按照索引选即可。

看下MAE中的例子。

def random_masking(x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

测试:

x = torch.randn(2,4,3)
x_masked, mask, ids_restore = random_masking(x,0.75)

在进入gather函数之前,各个张量的维度情况如下:
x:[2,4,3] [B,Token_length,Embed_dim]
ids_keep:[2,1] [B,Token_length(1-mask_ratio)]
ids_keep.unsqueeze(-1).repeat(1, 1, D):[2,1,3] [B,Token_length(1-mask_ratio),Embed_dim]

输出情况:

x tensor([[[-0.5777, -0.1997,  0.0505],
         [-0.7627,  0.6580,  0.4952],
         [-0.5960, -1.1513,  1.3593],
         [ 1.0174,  0.8969,  1.9184]],

        [[ 0.2932,  1.5911,  0.0404],
         [ 0.9809,  0.8083, -0.9814],
         [ 1.3939,  1.2482, -2.4358],
         [ 0.0136,  1.0753, -1.2225]]])
         
ids_keep.unsqueeze(-1).repeat(1, 1, D):tensor([[[2, 2, 2]],

        [[0, 0, 0]]])

x_masked tensor([[[-0.5960, -1.1513,  1.3593]],

        [[ 0.2932,  1.5911,  0.0404]]])

来进行分析,是从原始的x中根据mask后留下的张量来进行选择。从x中进行选择,从dim=1的地方选择,因为ids_keep为2,0,也就是说第一组x的dim1中的选索引为2的那个,即[-0.5960, -1.1513, 1.3593],另一个也是类似的。

x_masked:[2,1,3] 为掩码后没有被掩码的部分。

stack函数:stack API

stack函数不同于cat函数的直接拼接,会在维度上产生变化

a = torch.randn(3,2)
b = torch.randn(3,2)
print(torch.cat([a,b],dim=1).shape) #[3,4]
print(torch.stack([a,b],dim=1).shape)#[3,2,2]

stack类似于一个栈,依次压入了两个元素。

a:
tensor([[-1.3514,  0.3787],
        [-1.1666,  1.6129],
        [ 0.1213, -1.1316]])
b:
tensor([[ 0.7900,  0.8625],
        [-0.3513, -0.0850],
        [ 1.3164,  2.7623]])
torch.cat([a,b],dim=1):
tensor([[-1.3514,  0.3787,  0.7900,  0.8625],
        [-1.1666,  1.6129, -0.3513, -0.0850],
        [ 0.1213, -1.1316,  1.3164,  2.7623]])  
torch.stack([a,b],dim=1):   
tensor([[[-1.3514,  0.3787],
         [ 0.7900,  0.8625]],

        [[-1.1666,  1.6129],
         [-0.3513, -0.0850]],

        [[ 0.1213, -1.1316],
         [ 1.3164,  2.7623]]])      

从上面的例子可以看出cat是在对应的维度上直接的拼接,而stack则是分别从每个元素中取出一个再进行堆叠,其多出的那个维度就是设置的dim,并且个数就是进行操作的张量个数,本例中为2 因为是将a,b两个张量去进行了stack操作。若改为4,则shape对应[2,4,3]:

a = torch.randn(3,2)
b = torch.randn(3,2)
c = torch.randn(3,2)
d = torch.randn(3,2)
print(torch.stack([a,b,c,d],dim=1).shape) #torch.Size([3, 4, 2])

你可能感兴趣的:(深度学习框架笔记,pytorch,深度学习,python)