torch.unsqueeze官方文档解读

文章目录

  • 前言
  • 1. 功能介绍
  • 2. 参数说明
  • 3. 代码举例

前言

本文是对torch.unsqueeze官方文档的解读,加入部分代码示例,方便理解。

1. 功能介绍

torch.unsqueeze(input, dim)
Returns a new tensor with a dimension of size one inserted at the specified position.

返回一个新的tensor,在input的指定位置插入一维。
先别管这个官方的文档,只看文字很难理解什么是插入一个维度,一看代码举例就懂了。

x = torch.tensor([1, 2, 3, 4])  # 这里定义一个一维向量
print(torch.unsqueeze(x, 1))# 这里在dim=1这个位置插入一维,看效果
tensor([[1],
        [2],
        [3],
        [4]])

The returned tensor shares the same underlying data with this tensor.

返回的tensor和输入tensor共享底层数据。

A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.

dim的取值范围是[-input.dim() - 1, input.dim() + 1),注意左闭右开。
如果dim取负值,那么最终影响的dim是dim+input.dim() + 1。

x = torch.tensor([1, 2, 3, 4])
print(torch.unsqueeze(x, 0))
tensor([[1, 2, 3, 4]])

# dim=-1, 实际影响了 -1 + 1 + 1 = 1
print(torch.unsqueeze(x, -1))
tensor([[1],
        [2],
        [3],
        [4]])
# 所以 dim=1 和 dim=-1 效果一样
print(torch.unsqueeze(x, 1))
tensor([[1],
        [2],
        [3],
        [4]])
# dim = -2,实际影响了 -2 + 1 + 1 = 0, 所以和dim=0结果一样
print(torch.unsqueeze(x, -2))
tensor([[1, 2, 3, 4]])

对于一个一维向量,增加维度有两个方向。
dim=0这个方向,相当于直接把整个向量作为高维张量的一个值,可以说是水平扩展。
dim=1这个方向,相当于让原向量的每个维度独立成一个向量,可以说垂直扩展。

2. 参数说明

• input (Tensor) – 输入张量
• dim (int) – 指定在哪个index插入一个维度的

3. 代码举例

下面的代码展示了水平扩展。

d2 = torch.arange(20, dtype=torch.float32).reshape(5, 4)
print(d2)

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.],
        [16., 17., 18., 19.]])

d3 = d2.unsqueeze(0)
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.]]])

print(d3[0])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.],
        [16., 17., 18., 19.]])

你可能感兴趣的:(深度学习,pytorch,深度学习)