本文是对torch.unsqueeze官方文档的解读,加入部分代码示例,方便理解。
torch.unsqueeze
(input, dim)
Returns a new tensor with a dimension of size one inserted at the specified position.
返回一个新的tensor,在input的指定位置插入一维。
先别管这个官方的文档,只看文字很难理解什么是插入一个维度,一看代码举例就懂了。
x = torch.tensor([1, 2, 3, 4]) # 这里定义一个一维向量
print(torch.unsqueeze(x, 1))。 # 这里在dim=1这个位置插入一维,看效果
tensor([[1],
[2],
[3],
[4]])
The returned tensor shares the same underlying data with this tensor.
返回的tensor和输入tensor共享底层数据。
A
dim
value within the range[-input.dim() - 1, input.dim() + 1)
can be used. Negativedim
will correspond to unsqueeze() applied atdim
=dim + input.dim() + 1
.
dim的取值范围是[-input.dim() - 1, input.dim() + 1),注意左闭右开。
如果dim取负值,那么最终影响的dim是dim+input.dim() + 1。
x = torch.tensor([1, 2, 3, 4])
print(torch.unsqueeze(x, 0))
tensor([[1, 2, 3, 4]])
# dim=-1, 实际影响了 -1 + 1 + 1 = 1
print(torch.unsqueeze(x, -1))
tensor([[1],
[2],
[3],
[4]])
# 所以 dim=1 和 dim=-1 效果一样
print(torch.unsqueeze(x, 1))
tensor([[1],
[2],
[3],
[4]])
# dim = -2,实际影响了 -2 + 1 + 1 = 0, 所以和dim=0结果一样
print(torch.unsqueeze(x, -2))
tensor([[1, 2, 3, 4]])
对于一个一维向量,增加维度有两个方向。
dim=0这个方向,相当于直接把整个向量作为高维张量的一个值,可以说是水平扩展。
dim=1这个方向,相当于让原向量的每个维度独立成一个向量,可以说垂直扩展。
• input (Tensor) – 输入张量
• dim (int) – 指定在哪个index插入一个维度的
下面的代码展示了水平扩展。
d2 = torch.arange(20, dtype=torch.float32).reshape(5, 4)
print(d2)
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.]])
d3 = d2.unsqueeze(0)
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.]]])
print(d3[0])
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.]])