gather
函数:gather API这个函数在MAE模型中的random-masking函数中也有应用。
gather函数可以理解为根据对应的索引从原始的tensor中选择tensor,首先来看2-D的情况:
t = torch.tensor([[1,2],[3,4]])
torch.gather(t,1,torch.tensor([[0,0],[1,0]]))
#tensor([[1, 1],
# [4, 3]])
dim对应变化的索引位置,在上面的例子中dim=1,这就意味着是从t张量的维度1去选择的,从t[dim0][dim1]
。选择的索引对应[0,0]和[1,0]这就是说是dim0=0的时候选择t张量中dim0=0 dim1为给出的索引的元素,本例中给出的索引为[0,0],也就是说第一行的输出为t[0][0]
和t[0][0]
即都是元素1,同理在dim=1的时候就是t[1][1]
和t[1][0]
,也就是4,3。
再看个3D的例子:
t = torch.range(0,23).reshape(2,3,4)
'''
t:tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
'''
select = torch.tensor([[[2,1,0,0,],[2,0,0,0],[1,1,2,2]],[[2,1,1,2],[2,0,0,0],[1,1,2,2]]])
# select.shape # the same as t
torch.gather(t,dim=1,index=select)
'''
tensor([[[ 8., 5., 2., 3.],
[ 8., 1., 2., 3.],
[ 4., 5., 10., 11.]],
[[20., 17., 18., 23.],
[20., 13., 14., 15.],
[16., 17., 22., 23.]]])
'''
从上面的例子可以看出dim=1,所以对dim=1的元素进行选择,例如2,1,0,0对应的就是t张量在dim=1的时候
[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
分别按照索引选即可。
看下MAE中的例子。
def random_masking(x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
测试:
x = torch.randn(2,4,3)
x_masked, mask, ids_restore = random_masking(x,0.75)
在进入gather函数之前,各个张量的维度情况如下:
x
:[2,4,3] [B,Token_length,Embed_dim]
ids_keep
:[2,1] [B,Token_length(1-mask_ratio)]
ids_keep.unsqueeze(-1).repeat(1, 1, D)
:[2,1,3] [B,Token_length(1-mask_ratio),Embed_dim]
输出情况:
x tensor([[[-0.5777, -0.1997, 0.0505],
[-0.7627, 0.6580, 0.4952],
[-0.5960, -1.1513, 1.3593],
[ 1.0174, 0.8969, 1.9184]],
[[ 0.2932, 1.5911, 0.0404],
[ 0.9809, 0.8083, -0.9814],
[ 1.3939, 1.2482, -2.4358],
[ 0.0136, 1.0753, -1.2225]]])
ids_keep.unsqueeze(-1).repeat(1, 1, D):tensor([[[2, 2, 2]],
[[0, 0, 0]]])
x_masked tensor([[[-0.5960, -1.1513, 1.3593]],
[[ 0.2932, 1.5911, 0.0404]]])
来进行分析,是从原始的x中根据mask后留下的张量来进行选择。从x中进行选择,从dim=1的地方选择,因为ids_keep为2,0,也就是说第一组x的dim1中的选索引为2的那个,即[-0.5960, -1.1513, 1.3593],另一个也是类似的。
x_masked
:[2,1,3] 为掩码后没有被掩码的部分。
stack
函数:stack APIstack函数不同于cat函数的直接拼接,会在维度上产生变化
a = torch.randn(3,2)
b = torch.randn(3,2)
print(torch.cat([a,b],dim=1).shape) #[3,4]
print(torch.stack([a,b],dim=1).shape)#[3,2,2]
stack类似于一个栈,依次压入了两个元素。
a:
tensor([[-1.3514, 0.3787],
[-1.1666, 1.6129],
[ 0.1213, -1.1316]])
b:
tensor([[ 0.7900, 0.8625],
[-0.3513, -0.0850],
[ 1.3164, 2.7623]])
torch.cat([a,b],dim=1):
tensor([[-1.3514, 0.3787, 0.7900, 0.8625],
[-1.1666, 1.6129, -0.3513, -0.0850],
[ 0.1213, -1.1316, 1.3164, 2.7623]])
torch.stack([a,b],dim=1):
tensor([[[-1.3514, 0.3787],
[ 0.7900, 0.8625]],
[[-1.1666, 1.6129],
[-0.3513, -0.0850]],
[[ 0.1213, -1.1316],
[ 1.3164, 2.7623]]])
从上面的例子可以看出cat是在对应的维度上直接的拼接,而stack则是分别从每个元素中取出一个再进行堆叠,其多出的那个维度就是设置的dim,并且个数就是进行操作的张量个数,本例中为2 因为是将a,b两个张量去进行了stack操作。若改为4,则shape对应[2,4,3]:
a = torch.randn(3,2)
b = torch.randn(3,2)
c = torch.randn(3,2)
d = torch.randn(3,2)
print(torch.stack([a,b,c,d],dim=1).shape) #torch.Size([3, 4, 2])