pytorch中张量变换函数

在PyTorch中view(), transpose()permute() 函数都是用于改变张量(Tensor)维度结构的,但它们的作用和使用场景有所不同。

  1. torch.view()
    • 功能:该函数用于将一个张量重塑为新的形状,但它必须保持原有元素数量不变。它主要用于改变张量的维度布局,而不仅仅是交换维度。
    • 用法:通常用于简化或展开张量的维度,例如将三维张量展平成一维或二维。
import torch

batch = 3
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (batch, seq_size, embed)).float()
print(x)
# tensor([[[ 7., 14.,  2., 10.,  5., 17., 11.,  7.],
#          [24.,  4., 11., 21., 16., 21., 12., 24.]],
#
#         [[14.,  1., 13.,  5.,  0., 16.,  5., 22.],
#          [ 9.,  2., 21.,  6., 15.,  1., 16., 15.]],
#
#         [[23.,  4.,  4., 16.,  1., 18.,  0., 20.],
#          [ 9.,  1.,  1.,  7., 13., 21., 12., 12.]]])

# 将后两维度张量展平,将每一词的词嵌入按行连接
z = x.view(batch, -1)
print(z)
# tensor([[7., 14., 2., 10., 5., 17., 11., 7., 24., 4., 11., 21., 16., 21., 12., 24.],
#         [14., 1., 13., 5., 0., 16., 5., 22., 9., 2., 21., 6., 15., 1., 16., 15.],
#         [23., 4., 4., 16., 1., 18., 0., 20., 9., 1., 1., 7., 13., 21., 12., 12.]])

# transformer中多头注意力机制常用,把最后一维词嵌入的维度进行两次切割
# 切割出来多余的那部分做为batch放在第一个维度上
y = x.view(batch * 2, -1, embed // 2)
print(y)
# tensor([[[ 7., 14.,  2., 10.],
#          [ 5., 17., 11.,  7.]],
#
#         [[24.,  4., 11., 21.],
#          [16., 21., 12., 24.]],
#
#         [[14.,  1., 13.,  5.],
#          [ 0., 16.,  5., 22.]],
#
#         [[ 9.,  2., 21.,  6.],
#          [15.,  1., 16., 15.]],
#
#         [[23.,  4.,  4., 16.],
#          [ 1., 18.,  0., 20.]],
#
#         [[ 9.,  1.,  1.,  7.],
#          [13., 21., 12., 12.]]])

  1. torch.transpose()
    • 功能:该函数用于交换两个指定的维度(转置),其中给定轴上的元素被互换。
    • 用法:传入两个指定的参数维度且参数循序无关
import torch

batch = 3
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (batch, seq_size, embed)).float()
print(x)
# tensor([[[ 7., 14.,  2., 10.,  5., 17., 11.,  7.],
#          [24.,  4., 11., 21., 16., 21., 12., 24.]],
#
#         [[14.,  1., 13.,  5.,  0., 16.,  5., 22.],
#          [ 9.,  2., 21.,  6., 15.,  1., 16., 15.]],
#
#         [[23.,  4.,  4., 16.,  1., 18.,  0., 20.],
#          [ 9.,  1.,  1.,  7., 13., 21., 12., 12.]]])

# 将后两维交换(转置),将每一词的词嵌入按列展示
z = x.transpose(1, 2)  # 等价 x.transpose(2, 1)
print(z)
# tensor([[[ 7., 24.],
#          [14.,  4.],
#          [ 2., 11.],
#          [10., 21.],
#          [ 5., 16.],
#          [17., 21.],
#          [11., 12.],
#          [ 7., 24.]],
#
#         [[14.,  9.],
#          [ 1.,  2.],
#          [13., 21.],
#          [ 5.,  6.],
#          [ 0., 15.],
#          [16.,  1.],
#          [ 5., 16.],
#          [22., 15.]],
#
#         [[23.,  9.],
#          [ 4.,  1.],
#          [ 4.,  1.],
#          [16.,  7.],
#          [ 1., 13.],
#          [18., 21.],
#          [ 0., 12.],
#          [20., 12.]]])

  1. torch.permute()
    • 功能:该函数允许一次性重新排列多个维度,理解成transpose的扩展。
    • 用法:传入张量的所有维度,可以同时交换任意两个及以上的维度。
import torch

batch = 3
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (batch, seq_size, embed)).float()
print(x)
# tensor([[[ 7., 14.,  2., 10.,  5., 17., 11.,  7.],
#          [24.,  4., 11., 21., 16., 21., 12., 24.]],
#
#         [[14.,  1., 13.,  5.,  0., 16.,  5., 22.],
#          [ 9.,  2., 21.,  6., 15.,  1., 16., 15.]],
#
#         [[23.,  4.,  4., 16.,  1., 18.,  0., 20.],
#          [ 9.,  1.,  1.,  7., 13., 21., 12., 12.]]])

# 将后两维重新排序
# 注意这样是报错x.permute(2, 1)或者permute(1, 2, 1)都是非法的
z = x.permute(0, 2, 1)  # 等价 x.transpose(2, 1),
# print(z)

# 如果我们想要三个维度都交换transpose是做不到的
# 至于有什么实际意义就不讨论了
y = x.permute(2, 1, 0)
print(y)
# tensor([[[ 7., 14., 23.],
#          [24.,  9.,  9.]],
#
#         [[14.,  1.,  4.],
#          [ 4.,  2.,  1.]],
#
#         [[ 2., 13.,  4.],
#          [11., 21.,  1.]],
#
#         [[10.,  5., 16.],
#          [21.,  6.,  7.]],
#
#         [[ 5.,  0.,  1.],
#          [16., 15., 13.]],
#
#         [[17., 16., 18.],
#          [21.,  1., 21.]],
#
#         [[11.,  5.,  0.],
#          [12., 16., 12.]],
#
#         [[ 7., 22., 20.],
#          [24., 15., 12.]]])

  1. torch.unsqueeze()
    • 功能:增加一个新的维度。
    • 用法:增加维度指定的位置。
import torch

seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (seq_size, embed)).float()
print(x)
# tensor([[ 7., 14.,  2., 10.,  5., 17., 11.,  7.],
#         [24.,  4., 11., 21., 16., 21., 12., 24.]])

z = x.unsqueeze(0)  # 等价 torch.unsqueeze(x, dim=0)
print(z)
# tensor([[[ 7., 14.,  2., 10.,  5., 17., 11.,  7.],
#          [24.,  4., 11., 21., 16., 21., 12., 24.]]])

总结:

  • view() 更侧重于保持数据不变的前提下改变张量的维度形状,常用于展平、重塑等操作。

  • transpose() 是特定的维度交换操作,只涉及两个维度的变换。

  • permute() 则提供了更灵活的维度重排功能,可以处理多维度情况下的整体维度顺序调整。

  • unsqueeze() 指定位置增加张量维度。

你可能感兴趣的:(pytorch,人工智能,python)