Python中的flatten函数

通俗易懂的讲解就是他把指定维度后的维度合并到一起例如:

x = torch.randn(2,3,3)
x.flatten(1) 

展开后的维度为2*9

如果是flatten(0)即为18

如果是flatten(2)即在最后一个维度后展开,因为2已经为最后一维,故维度不变,依然为2*3*3

你可能感兴趣的:(python,python)