CNN记录】pytorch中flatten函数

pytorch原型

torch.flatten(input, start_dim=0, end_dim=- 1)

作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理,

参数:

start_dim:开始的维度

end_dim:终止的维度,-1为最后一个轴

默认值时展平为1维

例子

1、默认参数

input = torch.randn(2, 3, 4, 5)
output = torch.flatten(input)
输出维:torch.Size([120])

2、设置参数

input = torch.randn(2, 3, 4, 5)

output = torch.flatten(input,1)
输出shape为:torch.Size([2, 60])

output = torch.flatten(input,1,2)
输出shape为:torch.Size([2, 12, 5])

你可能感兴趣的:(DL,cnn,pytorch)