PyTorch中torch.flatten与torch.nn.Flatten用法讲解
# Self-contained: this is the first snippet, so bring torch into scope here.
import torch

x = torch.ones(2, 2, 2, 2)
# nn.Flatten() defaults to start_dim=1, end_dim=-1: dim 0 is kept and
# dims 1..3 are merged, giving shape (2, 2*2*2) = (2, 8).
# Named explicitly (not `F`) to avoid confusion with the common
# `import torch.nn.functional as F` alias.
flatten_layer = torch.nn.Flatten()
y = flatten_layer(x)
print(y)
print(y.shape)
>>tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1.]])
>>torch.Size([2, 8])
x = torch.ones(2, 2, 2, 2)
# Flatten(start_dim=2): dims 0 and 1 stay, dims 2..-1 merge -> (2, 2, 4).
flatten_from_dim2 = torch.nn.Flatten(start_dim=2)
y = flatten_from_dim2(x)
# Same output as printing the tensor and its shape on separate lines.
print(y, y.shape, sep="\n")
>>tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
>>torch.Size([2, 2, 4])
x = torch.ones(2, 2, 2, 2)
# Flatten(start_dim=1, end_dim=2): only dims 1 and 2 merge -> (2, 4, 2).
flatten_mid_dims = torch.nn.Flatten(start_dim=1, end_dim=2)
y = flatten_mid_dims(x)
# Same output as printing the tensor and its shape on separate lines.
print(y, y.shape, sep="\n")
>>tensor([[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]])
>>torch.Size([2, 4, 2])
# A 2x2x2 integer tensor holding 1..8, reused by the torch.flatten examples.
t = torch.tensor([
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
])
print(t.size())
>>torch.Size([2, 2, 2])
# torch.flatten defaults written out: start_dim=0, end_dim=-1 collapses every dim into 1-D.
print(torch.flatten(t, start_dim=0, end_dim=-1))
>>tensor([1, 2, 3, 4, 5, 6, 7, 8])
# Keep dim 0, merge dims 1..-1.
print(torch.flatten(t, start_dim=1))
>>tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
# Merge only dims 0 and 1; the last dim is left untouched.
print(torch.flatten(t, start_dim=0, end_dim=1).shape)
>>torch.Size([4, 2])
# A 0-d (scalar) tensor: its shape is the empty torch.Size([]).
t = torch.tensor(1)
# One print with newline separators emits the same three lines.
print("before flatten:", t, t.shape, sep="\n")
>>before flatten:
tensor(1)
torch.Size([])
# Flattening the 0-d tensor yields a 1-D tensor with one element;
# compute it once and reuse it for both prints.
flat = torch.flatten(t)
print("\n")
print("after flatten:")
print(flat, flat.shape, sep="\n")
>>after flatten:
tensor([1])
torch.Size([1])