【pytorch进阶】| 各类张量形状变化函数总结对比分析,view,reshape,pernute,transpose,squeeze,unsqueeze

文章目录

  • 1 view()函数
    • 1.1 基本用法
  • 2 view_as()函数
  • 3 reshape()函数
  • 4 permute()函数
  • 5 transpose() 函数
  • 6 squeeze()函数 和 unsqueeze()函数

1 view()函数

1.1 基本用法

view是将一个张量改变形状

函数原型

torch.Tensor.view(*shape) → Tensor

其中参数shape 可以是一个整数元组,或者是一个 系列整数

示例:两种不同参数比较

import torch

# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 参数用整数元组,变形为2x6的张量
y = x.view((2, 6))
print(y)
# Output:
# tensor([[ 1,  2,  3,  4,  5,  6],
#         [ 7,  8,  9, 10, 11, 12]])

# 参数用系列整数值,将其变形为1x12的张量
z = x.view(1, 12)
print(z)
# Output:
# tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

要求 变化后元素总的个数和变化前相同,对应上面的例子,变化前后都得是12个元素~(3,4)(2,6)(1,12)

注意上面的x,y,z都是共享底层内存的,怎么理解呢?x,y,z本质还是一个东西,y,z并不是x的副本

就是只要改变x,y,z中的其中一个,其他的张量都会受到影响改变

比如如下

# 修改视图中的元素,原始张量也会受到影响
y[0, 0] = 99
print(x)
# Output:
# tensor([[99,  2,  3,  4],
#         [ 5,  6,  7,  8],
#         [ 9, 10, 11, 12]])
print(y)
# Output:
# tensor([[99,  2,  3,  4,  5,  6],
#         [ 7,  8,  9, 10, 11, 12]])

print(z)
# Output:
# tensor([[99,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

还有一种常见的写法 -1,意味着这个值不是固定的,取决于其他维度,保证乘积不变即可

比如如下

import torch

# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

y=x.view(-1,4)
print(y.shape)
z=x.view(6,-1)
print(z.shape)

输出

torch.Size([3, 4])
torch.Size([6, 2])

2 view_as()函数

view 函数,不需要指定形状,只需要指定要保持的那个对应的张量即可

view_as 可以将一个张量的形状改为与另外一个张量相同

import torch

# 创建一个形状为(4, 2)的张量
x = torch.randn(4, 2)
print(x.shape)  # 输出:torch.Size([4, 2])

# 创建一个形状为(2, 2, 2)的张量
y = torch.randn(2, 2, 2)
print(y.shape)  # 输出:torch.Size([2, 2, 2])

# 使用view_as方法将x的形状改变为与y相同的形状
x = x.view_as(y)
print(x.shape)  # 输出:torch.Size([2, 2, 2])

注意区别

view方法需要你明确指定新的形状。例如,如果你有一个形状为(4, 2)的张量,你可以使用view(2, 4)来将其形状改变为(2, 4)

view_as方法则需要一个目标张量,它会将原始张量的形状改变为与目标张量相同的形状。

相同点是生成的 新的张量和原来张量都是共享底层内存的

3 reshape()函数

reshape使用整体和view差不多

reshape和view,大概率情况下会共享底层内存,但是在不连续的张量情况下(不连续发生在切片或者转置的时候),这时候会建立新的副本,这时候必须用reshape

例子

import torch

# 创建一个不连续的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

strided_tensor = tensor[:, ::2]  # 通过切片操作创建不连续的张量

# 使用 reshape 函数
reshaped_tensor =strided_tensor.reshape(4,1)


reshaped_tensor[0, 0] = 0
print("Original Tensor:", tensor)
print("Strided Tensor:", strided_tensor)
print("Reshaped Tensor:", reshaped_tensor)

输出

Original Tensor: tensor([[1, 2, 3],
        [4, 5, 6]])
Strided Tensor: tensor([[1, 3],
        [4, 6]])
Reshaped Tensor: tensor([[0],
        [3],
        [4],
        [6]])

如果这时候继续用view

会报如下错

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

4 permute()函数

这段代码是在使用PyTorch处理张量时对张量的维度进行重新排序的操作。

X.permute(1,0,2)是将X的维度进行重新排序。permute方法接收一组维度索引,然后按照这个索引的顺序重新排列张量的维度。

例如,如果X是一个三维张量,其维度是(batch_size, seq_length, feature_size),那么X.permute(1,0,2)将返回一个张量(也是共享底层内存)其维度是(seq_length, batch_size, feature_size)

这种操作在处理序列数据时非常常见,因为某些模型(如RNN、LSTM、GRU等)在输入数据时,需要序列长度(seq_length)在前,批量大小(batch_size)在后。所以,我们通常会使用permute或者transpose方法来调整维度的顺序。

示例如下

import torch

# 创建一个3维张量
x = torch.randn(2, 3, 4)

# 使用permute进行维度的置换
y = x.permute(1, 0, 2)
print(y.size())  # Output: torch.Size([3, 2, 4])

# 可以使用负数表示从最后一个维度开始的相对索引
z = x.permute(2, 1, 0)
print(z.size())  # Output: torch.Size([4, 3, 2])

5 transpose() 函数

在PyTorch中,transpose()函数用于交换张量(tensor)的维度。该函数返回一个新的张量,其维度顺序是原始张量维度的重新排列。

函数签名如下:

torch.transpose(input, dim0, dim1) -> Tensor
  • input: 输入的张量。
  • dim0: 第一个维度的索引。
  • dim1: 第二个维度的索引。

这个函数将input张量的dim0dim1两个维度进行交换。例如,如果input张量的形状是(a, b, c),并且你使用transpose(input, 0, 1),则返回的张量的形状将是(b, a, c),即交换了第一个和第二个维度。

以下是一个简单的示例:

import torch

# 创建一个3x4的张量
x = torch.rand((3, 4))

# 使用transpose函数交换维度
y = torch.transpose(x, 0, 1)

print("原始张量:", x)
print("交换维度后的张量:", y)

请注意,transpose()函数生成的张量并不会和原始张量共享内存,并不会修改原始张量,而是返回一个新的张量副本。如果你希望在原地操作(修改原始张量),可以使用transpose_()方法:

# 在原地操作,修改原始张量
x.transpose_(0, 1)
print("原地操作后的张量:", x)

这里的下划线表示原地操作。

6 squeeze()函数 和 unsqueeze()函数

在PyTorch中,squeezeunsqueeze是用于操作张量形状的函数,用于增加或减少维度。

  1. squeeze函数:

    • torch.squeeze(input, dim=None, out=None)函数用于删除张量中维度为1的轴。如果指定了dim参数,则只会在指定轴上删除大小为1的维度,否则会删除所有大小为1的维度。
    • 参数:
      • input: 输入的张量。
      • dim (可选): 要挤压的维度,如果指定,则只删除指定维度上的大小为1的轴。
      • out (可选): 输出张量,如果指定,则将结果存储在此张量中。
    • 返回值:挤压后的张量。

    示例:

    import torch
    
    x = torch.randn(1, 3, 1, 4)
    y = torch.squeeze(x)  # 在所有大小为1的维度上进行挤压
    z = torch.squeeze(x, dim=2)  # 只在维度2上挤压大小为1的轴
    
    print(x.shape)  # 输出: torch.Size([1, 3, 1, 4])
    print(y.shape)  # 输出: torch.Size([3, 4])
    print(z.shape)  # 输出: torch.Size([1, 3, 4])
    
  2. unsqueeze函数:

    • torch.unsqueeze(input, dim)函数用于在张量的指定位置插入维度为1的轴。
    • 参数:
      • input: 输入的张量。
      • dim: 插入维度为1的轴的位置。
    • 返回值:插入维度为1的轴后的张量。

    示例:

    import torch
    
    x = torch.randn(3, 4)
    y = torch.unsqueeze(x, dim=1)  # 在维度1上插入大小为1的轴
    
    print(x.shape)  # 输出: torch.Size([3, 4])
    print(y.shape)  # 输出: torch.Size([3, 1, 4])
    

增加维度还可以通过None的方式增加

import torch

# 二维张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 在第一个轴上增加维度
y = x[:, None, :]
# 或者使用 torch.unsqueeze
# y = torch.unsqueeze(x, dim=1)

print(x.shape)  # 输出: torch.Size([2, 3])
print(y.shape)  # 输出: torch.Size([2, 1, 3])

你可能感兴趣的:(Python学习,pytorch,人工智能,深度学习)