语法:
torch.flatten(input, start_dim=0, end_dim=-1)
参数:
input:输入的tensor,即将要被 “弄平” 的tensor
start_dim:“弄平” 的起始纬度,默认值为0
end_dim: “弄平” 的终止维度,默认值为-1,-1指的是最后一个纬度
作用:
从起始纬度到终止纬度,将输入的tensor弄平
示例:
1、生成一个size为(2,3,2)的全1tensor
>>> import torch
>>> a = torch.ones(2,3,2)
tensor([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]])
2、start_dim,end_dim使用默认值,就是将a整个弄平
>>> torch.flatten(a)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
3、同2,对 0,1,2纬度进行弄平,将a整个弄平
>>> torch.flatten(a, 0, 2)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
4、对 0,1纬度进行弄平,size变为(2*3, 2)=(6,2)
>>> torch.flatten(a, 0, 1)
tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])
5、对 1,2纬度进行弄平,size变为(2, 3*2)=(2, 6)
torch.flatten(a, 1,2)
tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]])
>>> torch.flatten(a,1,-1) # 同上,对 1,2纬度进行弄平,size变为(2, 3*2)=(2, 6)
tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]])
# start_dim=1,没有设置end_dim,使用默认值end_dim=-1,
# 故同上,对 1,2纬度进行弄平,size变为(2, 3*2)=(2, 6)
>>> torch.flatten(a,1)
tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]])