torch.squeeze
torch.squeeze(input, dim=None, out=None) → Tensor
分为两种情况: 不指定维度 或 指定维度
不指定维度
input: (A, B, 1, C, 1, D) output: (A, B, C, D)
Example
>>> x = torch.zeros(2, 1, 2, 1, 2) >>> x.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x) >>> y.size() torch.Size([2, 2, 2])
指定维度
input: (A, 1, B)
&torch.squeeze(input, 0)
output: (A, 1, B)
input: (A, 1, B)
&torch.squeeze(input, 1)
output: (A, B)
Example
>>> x = torch.zeros(2, 1, 2, 1, 2) >>> x.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x, 0) >>> y.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x, 1) >>> y.size() torch.Size([2, 2, 1, 2])
Note:
The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.
也就是说, squeeze 所返回的 tensor 与 输入 tensor 共享内存, 所以如果改变其中一项的值另一项也会随着改变.
torch.unsqueeze
torch.unsqueeze(input, dim, out=None) → Tensor
Note: 这里与 squeeze 不同的是 unsqueeze 必须指定维度.
同时, unsqueeze 所返回的 tensor 与 输入的 tensor 也是共享内存的.
>>> import torch
>>> a = torch.zeros([2, 2])
>>> a.shape
torch.Size([2, 2])
>>> a
tensor([[0., 0.],
[0., 0.]])
>>> b = torch.unsqueeze(a, dim=0)
>>> b.shape
torch.Size([1, 2, 2])
>>> b
tensor([[[0., 0.],
[0., 0.]]])
>>> b[0, 0, 1] = 1
>>> b
tensor([[[0., 1.],
[0., 0.]]])
>>> a
tensor([[0., 1.],
[0., 0.]])
>>> b = torch.unsqueeze(a, dim=1)
>>> b.shape
torch.Size([2, 1, 2])
>>> b
tensor([[[0., 1.]],
[[0., 0.]]])