【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读

目录

【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读 

unsqueeze() 函数的作用:

语法:

unsqueeze() 操作示例:

示例 1:将一个一维张量转换为二维张量

示例 2:在最后一维插入一个新维度

示例 3:负索引插入维度

示例 4:将二维张量转为三维张量

总结:


【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读 

在 PyTorch 中,unsqueeze() 是一个非常实用的函数,用于在张量的指定位置插入一个维度。

简而言之,unsqueeze() 通过增加一个长度为1的维度来扩展张量的维度。

unsqueeze() 函数的作用:

unsqueeze() 函数将一个张量的维度增加 1

这个函数常用于调整张量的形状,特别是在需要将一个二维或一维张量转换为更高维度的张量时。

语法:

torch.unsqueeze(input, dim)
  • input:输入张量。
  • dim:指定要插入新维度的位置。dim 是一个整数,表示新维度的位置,取值范围是 [-input.dim() - 1, input.dim()]。如果 dim 为负数,它表示从最后一个维度开始计数。

unsqueeze() 操作示例:

示例 1:将一个一维张量转换为二维张量

假设我们有一个一维张量 [1, 2, 3],我们希望通过 unsqueeze() 将其转换为一个二维张量,并在第 0 维度(最前面)插入一个新的维度。

import torch

# 创建一个一维张量
x = torch.tensor([1, 2, 3])

# 在第0维插入一个新的维度
y = torch.unsqueeze(x, 0)

print("Original shape:", x.shape)  # 原始张量形状
print("New shape:", y.shape)  # 新张量形状
print(y)

输出:

Original shape: torch.Size([3])
New shape: torch.Size([1, 3])
tensor([[1, 2, 3]])
  • 解释
    • 原始张量 x 的形状是 (3),表示这是一个包含 3 个元素的一维张量。
    • 使用 torch.unsqueeze(x, 0) 后,在张量的第 0 维插入了一个新的维度。结果是一个形状为 (1, 3) 的二维张量。
    • unsqueeze(0) 会在第一个维度(最前面)插入新的维度,表示这个张量现在有 1 行,3 列。
示例 2:在最后一维插入一个新维度

假设我们希望将张量 [1, 2, 3] 变成形状为 (3, 1) 的二维张量,我们可以在第 1 维(最后一维)插入一个新的维度。

# 在第1维插入一个新的维度
z = torch.unsqueeze(x, 1)

print("Original shape:", x.shape)
print("New shape:", z.shape)
print(z)

输出:

Original shape: torch.Size([3])
New shape: torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
  • 解释
    • 原始张量 x 的形状是 (3),是一个一维张量。
    • 使用 torch.unsqueeze(x, 1) 后,在第 1 维(即最后一个维度)插入了一个新的维度。结果是一个形状为 (3, 1) 的二维张量,表示这个张量现在有 3 行,1 列。
示例 3:负索引插入维度

我们可以使用负数索引来指定维度的位置。负数表示从最后一个维度开始计数。

# 在倒数第一维(最后一维)插入一个新的维度
w = torch.unsqueeze(x, -1)

print("Original shape:", x.shape)
print("New shape:", w.shape)
print(w)

输出:

Original shape: torch.Size([3])
New shape: torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
  • 解释
    • 使用 torch.unsqueeze(x, -1) 等同于使用 torch.unsqueeze(x, 1),在张量的最后一个维度插入了一个新的维度。
    • 结果是一个形状为 (3, 1) 的二维张量,表示张量现在有 3 行,1 列。
示例 4:将二维张量转为三维张量

如果我们有一个形状为 (2, 3) 的二维张量,并希望将其转换为三维张量(例如,插入一个维度表示批次大小),我们可以使用 unsqueeze()

# 创建一个二维张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 在第0维插入新维度
b = torch.unsqueeze(a, 0)

print("Original shape:", a.shape)
print("New shape:", b.shape)
print(b)

输出:

Original shape: torch.Size([2, 3])
New shape: torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
  • 解释
    • 原始张量 a 的形状是 (2, 3),表示它有 2 行,3 列。
    • 使用 torch.unsqueeze(a, 0) 后,在第 0 维(最前面)插入了一个新的维度,结果是一个形状为 (1, 2, 3) 的三维张量,表示这个张量现在有 1 个批次,2 行,3 列。

总结:

  • unsqueeze() 函数用于增加张量的维度,可以通过指定维度位置插入一个新的维度(长度为 1)。
  • 常见应用:增加批次维度(例如将一维张量转换为二维张量)或调整张量形状以满足模型输入的要求。
  • 这个函数特别有用,在需要将张量的维度对齐时,或者在深度学习框架中调整数据形状时非常常见。

你可能感兴趣的:(笔记,算法,python,开发语言)