python:flatten()参数详解

  • flatten()是对多维数据的降维函数。
  • flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。
  • python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
    在这里插入图片描述

比如我们随机定义一个维度为(2,3,4)的数据a

import torch
a = torch.rand(2,3,4)

a输出结果为:

python:flatten()参数详解_第1张图片
a此时的维度为(2,3,4)

flatten()和flatten(0)效果一样,a这个数据从0维展开,就是(2 ∗ 3 ∗ 4 2342∗3∗4),维度就是(24)


你可能感兴趣的:(Python,python)