eniops库中pack函数使用方法

pack

就是打包。

举个例子

import torch
import einops

# 创建输入张量
x = torch.randn(1, 6, 1, 2)  # 形状 (1, 6, 1, 2)

# 使用 pack 打包,注意输入必须是数组,所以这里要加一个[]
flatten, ps = einops.pack([x], 'h * d')

print("x shape:", x.shape)        # 输出: torch.Size([1, 6, 1, 2])
print("flatten shape:", flatten.shape)  # 输出: torch.Size([1, 6, 2])
print("ps:", ps)  # 输出: (6, 1)

模式字符串的含义, ‘h * d’ 是输出模式,表示:

  • h:保留第 0 维(大小为 1)。

  • *:将剩余的维度(第 1 维和第 2 维)展平为一个新的维度。

  • d:保留第 3 维(大小为 2)。

输出结果

  • flatten:这是打包后的张量。根据模式 ‘h * d’,flatten 的形状为 (h, w * c, d),即 (1, 6 * 1, 2) = (1, 6, 2)。
    它保留了第 0 维和第 3 维,同时将第 1 维和第 2 维展平。

  • ps:这是 pack 函数的附加输出,表示打包的形状信息。它是一个元组,记录了展平前的维度信息。对于 x 的形状 (1, 6, 1, 2),ps 的值可能是 (6, 1),表示第 1 维和第 2 维的原始形状。

如果不使用pack,则可以这样实现

import torch

# 创建输入张量
x = torch.randn(1, 6, 1, 2)  # 形状 (1, 6, 1, 2)

# 将第 1 维和第 2 维展平
x_reshaped = x.reshape(1, 6 * 1, 2)  # 形状 (1, 6, 2)

# 如果需要 ps 信息
ps = (6, 1)

print("x shape:", x.shape)        # 输出: torch.Size([1, 6, 1, 2])
print("x_reshaped shape:", x_reshaped.shape)  # 输出: torch.Size([1, 6, 2])
print("ps:", ps)  # 输出: (6, 1)

你可能感兴趣的:(python,深度学习,pytorch)