【深度学习】torch.nn.Flatten()和torch.flatten()的区别

torch.nn.Flatten 和 torch.flatten两个函数的区别

在深度学习模型训练和测试之前,通常要对tensor数据进行预处理,在处理过程中,涉及到将高纬度数据降维的操作。

torch.nn.Flatten 和 torch.flatten 都被用来进行降维操作。

torch.nn.Flatten() 有如下性质:

默认参数为start_dim = 1 , end_dim = -1,即从第1维(第0维不变)开始到最后一维结束将每个batch拉伸成一维:
【深度学习】torch.nn.Flatten()和torch.flatten()的区别_第1张图片
在这里插入图片描述
当仅设置一个参数时,该参数表示 start_dim 的值,即从该维度开始到最后一个维度结束,将每个batch拉伸成一维,其余维度不变:
【深度学习】torch.nn.Flatten()和torch.flatten()的区别_第2张图片
在这里插入图片描述当设置两个参数时,两个参数分别表示开始维度和结束维度:

【深度学习】torch.nn.Flatten()和torch.flatten()的区别_第3张图片
在这里插入图片描述torch.nn.Flatten()函数官方文档:
【深度学习】torch.nn.Flatten()和torch.flatten()的区别_第4张图片
对于torch.flatten():

torch.flatten()函数默认start_dim = 0 , 其余与torch.nn.flatten()相同。

【深度学习】torch.nn.Flatten()和torch.flatten()的区别_第5张图片

torch.flatten()函数官方文档:

【深度学习】torch.nn.Flatten()和torch.flatten()的区别_第6张图片

你可能感兴趣的:(python,深度学习,神经网络,pytorch)