torch.
flatten
(input, start_dim=0, end_dim=- 1) → Tensorinput (Tensor) – the input tensor.
start_dim (int) – the first dim to flatten
end_dim (int) – the last dim to flatten
例如数据维度[2,3,4,4]
如果flatten(Input, start_dim = 1),则表示从第一维开始打平所有数据直到最后一维,则最终的数据被打平为两组。
如果flatten(Input, start_dim = 0),则表示从第0维开始打平所有数据直到最后一维,则最终的数据被打平为一组。
x = torch.randint(0,5,[2,3,4,4])
x
Out[31]:
tensor([[[[1, 2, 4, 4],
[1, 3, 2, 0],
[1, 2, 1, 0],
[2, 0, 4, 2]],
[[2, 0, 1, 3],
[3, 2, 0, 4],
[4, 0, 4, 1],
[0, 1, 1, 2]],
[[0, 2, 2, 1],
[2, 3, 1, 3],
[1, 4, 3, 0],
[1, 4, 0, 3]]],
[[[0, 0, 0, 4],
[2, 1, 1, 4],
[2, 3, 2, 3],
[0, 1, 0, 0]],
[[3, 0, 2, 0],
[2, 0, 0, 3],
[1, 1, 2, 2],
[2, 3, 3, 3]],
[[2, 0, 1, 4],
[4, 3, 0, 1],
[3, 2, 1, 4],
[3, 0, 0, 0]]]])
torch.flatten(x,1)
Out[32]:
tensor([[1, 2, 4, 4, 1, 3, 2, 0, 1, 2, 1, 0, 2, 0, 4, 2, 2, 0, 1, 3, 3, 2, 0, 4,
4, 0, 4, 1, 0, 1, 1, 2, 0, 2, 2, 1, 2, 3, 1, 3, 1, 4, 3, 0, 1, 4, 0, 3],
[0, 0, 0, 4, 2, 1, 1, 4, 2, 3, 2, 3, 0, 1, 0, 0, 3, 0, 2, 0, 2, 0, 0, 3,
1, 1, 2, 2, 2, 3, 3, 3, 2, 0, 1, 4, 4, 3, 0, 1, 3, 2, 1, 4, 3, 0, 0, 0]])
torch.flatten(x,1,-1)
Out[33]:
tensor([[1, 2, 4, 4, 1, 3, 2, 0, 1, 2, 1, 0, 2, 0, 4, 2, 2, 0, 1, 3, 3, 2, 0, 4,
4, 0, 4, 1, 0, 1, 1, 2, 0, 2, 2, 1, 2, 3, 1, 3, 1, 4, 3, 0, 1, 4, 0, 3],
[0, 0, 0, 4, 2, 1, 1, 4, 2, 3, 2, 3, 0, 1, 0, 0, 3, 0, 2, 0, 2, 0, 0, 3,
1, 1, 2, 2, 2, 3, 3, 3, 2, 0, 1, 4, 4, 3, 0, 1, 3, 2, 1, 4, 3, 0, 0, 0]])
torch.flatten(x,0)
Out[34]:
tensor([1, 2, 4, 4, 1, 3, 2, 0, 1, 2, 1, 0, 2, 0, 4, 2, 2, 0, 1, 3, 3, 2, 0, 4,
4, 0, 4, 1, 0, 1, 1, 2, 0, 2, 2, 1, 2, 3, 1, 3, 1, 4, 3, 0, 1, 4, 0, 3,
0, 0, 0, 4, 2, 1, 1, 4, 2, 3, 2, 3, 0, 1, 0, 0, 3, 0, 2, 0, 2, 0, 0, 3,
1, 1, 2, 2, 2, 3, 3, 3, 2, 0, 1, 4, 4, 3, 0, 1, 3, 2, 1, 4, 3, 0, 0, 0])
torch.flatten(x,0,-1)
Out[37]:
tensor([1, 2, 4, 4, 1, 3, 2, 0, 1, 2, 1, 0, 2, 0, 4, 2, 2, 0, 1, 3, 3, 2, 0, 4,
4, 0, 4, 1, 0, 1, 1, 2, 0, 2, 2, 1, 2, 3, 1, 3, 1, 4, 3, 0, 1, 4, 0, 3,
0, 0, 0, 4, 2, 1, 1, 4, 2, 3, 2, 3, 0, 1, 0, 0, 3, 0, 2, 0, 2, 0, 0, 3,
1, 1, 2, 2, 2, 3, 3, 3, 2, 0, 1, 4, 4, 3, 0, 1, 3, 2, 1, 4, 3, 0, 0, 0])