自己的理解:
介绍torch.flatten()函数的具体使用方法 1、首先创建一个三维张量 2、调用torch.flatten()函数
import torch
x = torch.randn(2, 3, 4)
print(x.shape)
# torch.Size([2, 3, 4])
# 函数详解:
A = torch.flatten(x) # 降维成一维向量
print(A.shape) # torch.Size([24])
print(A)
# tensor([ 0.5366, 0.5488, 0.4033, -0.5649, -0.9119, 0.8464, -2.6698, 0.9435,
# 1.4485, 0.6482, -0.3760, 0.4114, 0.1044, 0.8057, -0.6402, 0.4294,
# 0.4673, -3.0244, 0.4310, -0.2473, 0.9410, 0.1142, 1.8234, 0.9855])
B = torch.flatten(x, 1) # 以行降维--二维
print(B.shape) # torch.Size([2, 12])
print(B)
# tensor([[-0.6268, 2.0879, -0.3395, 0.7372, 1.2479, 0.7701, -1.0685, -0.1118,
# 0.8185, -0.7564, -0.1037, 0.6884],
# [ 0.0925, -0.9300, -1.2214, 0.1166, 1.7271, 0.8715, -0.0598, 0.5371,
# 0.3556, -0.7636, 0.4855, 0.0844]])