PyTorch 中的维度操作详解

在 PyTorch 中,维度(dimension) 是描述张量形状的一种方式。维度操作是 PyTorch 中非常重要的功能,常用于调整张量的形状以适配各种计算需求。以下是常见的维度操作及其示例。


1. 维度的概念回顾

  • 一个二维张量(矩阵)的形状是 (行数, 列数)
  • 一个三维张量的形状是 (深度, 行数, 列数)
  • 维度的索引从 0 开始,最外层是 axis=0,向内依次递增。

2. 维度的操作

(1) 求和(Sum)

sum(dim) 的作用是沿着指定的维度对张量求和,并 移除该维度

示例:二维张量
import torch

tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape)  # (2, 3)
  • 沿维度 0 求和(dim=0

    sum_axis0 = tensor.sum(dim=0)
    print("沿维度 0 求和:\n", sum_axis0)
    

    输出

    沿维度 0 求和:
     tensor([5, 7, 9])
    

    解释

    • 维度 0 是行方向。
    • 对每一列的元素求和:
      • 第 0 列:1 + 4 = 5
      • 第 1 列:2 + 5 = 7
      • 第 2 列:3 + 6 = 9
  • 沿维度 1 求和(dim=1

    sum_axis1 = tensor.sum(dim=1)
    print("沿维度 1 求和:\n", sum_axis1)
    

    输出

    沿维度 1 求和:
     tensor([6, 15])
    

    解释

    • 维度 1 是列方向。
    • 对每一行的元素求和:
      • 第 0 行:1 + 2 + 3 = 6
      • 第 1 行:4 + 5 + 6 = 15
示例:三维张量
tensor = torch.tensor([[[1, 2], [3, 4]],
                       [[5, 6], [7, 8]]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape)  # (2, 2, 2)
  • 沿维度 0 求和(dim=0

    sum_axis0 = tensor.sum(dim=0)
    print("沿维度 0 求和:\n", sum_axis0)
    

    输出

    沿维度 0 求和:
     tensor([[ 6,  8],
             [10, 12]])
    

    解释

    • 维度 0 是最外层(矩阵的数量)。
    • 对两个矩阵的对应位置元素求和:
      • [1, 2] + [5, 6] = [6, 8]
      • [3, 4] + [7, 8] = [10, 12]

(2) 增加维度(Unsqueeze)

unsqueeze(dim) 的作用是在指定维度上增加一个大小为 1 的维度。

示例:二维张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape)  # (2, 3)
  • 在维度 0 增加维度

    unsqueeze_axis0 = tensor.unsqueeze(0)
    print("在维度 0 增加维度:\n", unsqueeze_axis0)
    print("形状:", unsqueeze_axis0.shape)  # (1, 2, 3)
    

    输出

    在维度 0 增加维度:
     tensor([[[1, 2, 3],
              [4, 5, 6]]])
    形状: (1, 2, 3)
    
  • 在维度 1 增加维度

    unsqueeze_axis1 = tensor.unsqueeze(1)
    print("在维度 1 增加维度:\n", unsqueeze_axis1)
    print("形状:", unsqueeze_axis1.shape)  # (2, 1, 3)
    

    输出

    在维度 1 增加维度:
     tensor([[[1, 2, 3]],
             [[4, 5, 6]]])
    形状: (2, 1, 3)
    

(3) 移除维度(Squeeze)

squeeze(dim) 的作用是移除指定维度上大小为 1 的维度。

示例:三维张量
tensor = torch.tensor([[[1, 2, 3]]])
print("原始张量:\n", tensor)
print("形状:", tensor.shape)  # (1, 1, 3)
  • 移除所有大小为 1 的维度

    squeeze_tensor = tensor.squeeze()
    print("移除所有大小为 1 的维度:\n", squeeze_tensor)
    print("形状:", squeeze_tensor.shape)  # (3,)
    

    输出

    移除所有大小为 1 的维度:
     tensor([1, 2, 3])
    形状: (3,)
    
  • 移除指定维度

    squeeze_dim0 = tensor.squeeze(0)
    print("移除维度 0:\n", squeeze_dim0)
    print("形状:", squeeze_dim0.shape)  # (1, 3)
    

    输出

    移除维度 0:
     tensor([[1, 2, 3]])
    形状: (1, 3)
    

3. 总结

  • 求和(sum:沿指定维度对元素求和,并移除该维度。
  • 增加维度(unsqueeze:在指定维度上增加一个大小为 1 的维度。
  • 移除维度(squeeze:移除指定维度上大小为 1 的维度。

通过这些操作,可以灵活调整张量的形状,使其适配各种计算需求!

你可能感兴趣的:(pytorch,人工智能,python)