torch之torch.flatten()

  • input: 输入,类型为Tensor。
  • start_dim: 推平的起始维度。
  • end_dim: 推平的结束维度。
import torch

a = torch.ones(2,3,4,5)

b = torch.flatten(a,start_dim=0,end_dim=2)
# 从0维开始往后推,推到第2维。所以最后应该是:(2*3*4,5)
print(b.shape)

b = torch.flatten(a,end_dim=2)
# 默认为0
print(b.shape)

b = torch.flatten(a,start_dim=-1)
# 从最后一维往后退,不变
print(b.shape)

b = torch.flatten(a,end_dim=-1)
# 推到最后一维,展平
print(b.shape)

Result:

torch.Size([24, 5])
torch.Size([24, 5])
torch.Size([2, 3, 4, 5])
torch.Size([120])

你可能感兴趣的:(#,pytorch,pytorch)