torch.flatten()函数详解

自己的理解:

介绍torch.flatten()函数的具体使用方法
 1、首先创建一个三维张量
 2、调用torch.flatten()函数
import torch

x = torch.randn(2, 3, 4)
print(x.shape)

# torch.Size([2, 3, 4])
# 函数详解:
A = torch.flatten(x)  # 降维成一维向量
print(A.shape)  # torch.Size([24])
print(A)
# tensor([ 0.5366,  0.5488,  0.4033, -0.5649, -0.9119,  0.8464, -2.6698,  0.9435,
#          1.4485,  0.6482, -0.3760,  0.4114,  0.1044,  0.8057, -0.6402,  0.4294,
#          0.4673, -3.0244,  0.4310, -0.2473,  0.9410,  0.1142,  1.8234,  0.9855])


B = torch.flatten(x, 1)  # 以行降维--二维
print(B.shape)  # torch.Size([2, 12])
print(B)
# tensor([[-0.6268,  2.0879, -0.3395,  0.7372,  1.2479,  0.7701, -1.0685, -0.1118,
#           0.8185, -0.7564, -0.1037,  0.6884],
#         [ 0.0925, -0.9300, -1.2214,  0.1166,  1.7271,  0.8715, -0.0598,  0.5371,
#           0.3556, -0.7636,  0.4855,  0.0844]])

你可能感兴趣的:(琐碎的小知识,pytorch,深度学习,机器学习)