目录
【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读
unsqueeze() 函数的作用:
语法:
unsqueeze() 操作示例:
示例 1:将一个一维张量转换为二维张量
示例 2:在最后一维插入一个新维度
示例 3:负索引插入维度
示例 4:将二维张量转为三维张量
总结:
在 PyTorch 中,
unsqueeze()
是一个非常实用的函数,用于在张量的指定位置插入一个维度。
简而言之,
unsqueeze()
通过增加一个长度为1的维度来扩展张量的维度。
unsqueeze()
函数的作用:
unsqueeze()
函数将一个张量的维度增加 1。
这个函数常用于调整张量的形状,特别是在需要将一个二维或一维张量转换为更高维度的张量时。
torch.unsqueeze(input, dim)
input
:输入张量。dim
:指定要插入新维度的位置。dim
是一个整数,表示新维度的位置,取值范围是 [-input.dim() - 1, input.dim()]
。如果 dim
为负数,它表示从最后一个维度开始计数。unsqueeze()
操作示例:假设我们有一个一维张量 [1, 2, 3]
,我们希望通过 unsqueeze()
将其转换为一个二维张量,并在第 0 维度(最前面)插入一个新的维度。
import torch
# 创建一个一维张量
x = torch.tensor([1, 2, 3])
# 在第0维插入一个新的维度
y = torch.unsqueeze(x, 0)
print("Original shape:", x.shape) # 原始张量形状
print("New shape:", y.shape) # 新张量形状
print(y)
输出:
Original shape: torch.Size([3])
New shape: torch.Size([1, 3])
tensor([[1, 2, 3]])
x
的形状是 (3)
,表示这是一个包含 3 个元素的一维张量。torch.unsqueeze(x, 0)
后,在张量的第 0 维插入了一个新的维度。结果是一个形状为 (1, 3)
的二维张量。unsqueeze(0)
会在第一个维度(最前面)插入新的维度,表示这个张量现在有 1 行,3 列。假设我们希望将张量 [1, 2, 3]
变成形状为 (3, 1)
的二维张量,我们可以在第 1 维(最后一维)插入一个新的维度。
# 在第1维插入一个新的维度
z = torch.unsqueeze(x, 1)
print("Original shape:", x.shape)
print("New shape:", z.shape)
print(z)
输出:
Original shape: torch.Size([3])
New shape: torch.Size([3, 1])
tensor([[1],
[2],
[3]])
x
的形状是 (3)
,是一个一维张量。torch.unsqueeze(x, 1)
后,在第 1 维(即最后一个维度)插入了一个新的维度。结果是一个形状为 (3, 1)
的二维张量,表示这个张量现在有 3 行,1 列。我们可以使用负数索引来指定维度的位置。负数表示从最后一个维度开始计数。
# 在倒数第一维(最后一维)插入一个新的维度
w = torch.unsqueeze(x, -1)
print("Original shape:", x.shape)
print("New shape:", w.shape)
print(w)
输出:
Original shape: torch.Size([3])
New shape: torch.Size([3, 1])
tensor([[1],
[2],
[3]])
torch.unsqueeze(x, -1)
等同于使用 torch.unsqueeze(x, 1)
,在张量的最后一个维度插入了一个新的维度。(3, 1)
的二维张量,表示张量现在有 3 行,1 列。如果我们有一个形状为 (2, 3)
的二维张量,并希望将其转换为三维张量(例如,插入一个维度表示批次大小),我们可以使用 unsqueeze()
。
# 创建一个二维张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 在第0维插入新维度
b = torch.unsqueeze(a, 0)
print("Original shape:", a.shape)
print("New shape:", b.shape)
print(b)
输出:
Original shape: torch.Size([2, 3])
New shape: torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
[4, 5, 6]]])
a
的形状是 (2, 3)
,表示它有 2 行,3 列。torch.unsqueeze(a, 0)
后,在第 0 维(最前面)插入了一个新的维度,结果是一个形状为 (1, 2, 3)
的三维张量,表示这个张量现在有 1 个批次,2 行,3 列。unsqueeze()
函数用于增加张量的维度,可以通过指定维度位置插入一个新的维度(长度为 1)。