PyTorch中squeeze()和unsqueeze()函数理解

squeeze(arg)
表示若第arg维的维度值为1,则去掉该维度,否则tensor不变。(即若tensor.shape()[arg] == 1,则去掉该维度)
例如:
一个维度为2x1x2x1x2的tensor,不用去想它长什么样儿,squeeze(0)就是不变,squeeze(1)就是变成2x2x1x2。(0是从最左边的维度算起的)

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

unsqueeze(arg)
与squeeze(arg)函数作用相反,表示在第arg维增加一个维度为1的维度。
啥意思呢?
比如一个tensor的shape为3x3,那么unsqueeze(0)就是变成1x3x3,unsqueeze(1)就是变成3x1x3.
再如下面这个官方的例子,得看好几眼才能看明白怎么回事。
其实可以这样理解:x的shape为:4,unsqueeze(0)就是把shape变成1x4;unsqueeze(1)就是把shape变成4x1。

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])

参考:
[1] https://pytorch.org/docs/1.11/generated/torch.unsqueeze.html#torch.unsqueeze
[2] https://www.cnblogs.com/sbj123456789/p/9231571.html

你可能感兴趣的:(深度学习,pytorch,深度学习,python)