PyTorch 中的 flatten 函数
torch.flatten()
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]],
...                   [[9, 10],
...                    [11, 12]]])
>>> torch.flatten(t)
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
>>> torch.flatten(t, start_dim=1)
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
>>> torch.flatten(t, start_dim=1).size()
torch.Size([3, 4])
>>> torch.flatten(t, start_dim=0, end_dim=1)
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
>>> torch.flatten(t, start_dim=0, end_dim=1).size()
torch.Size([6, 2])
torch.nn.Flatten()
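class torch.nn.Flatten(start_dim=1, end_dim=-1)
注意：torch.flatten() 的 start_dim 默认为 0（展平所有维度），而 nn.Flatten 的 start_dim 默认为 1，即保留第 0 维（通常是 batch 维），只展平其余维度。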
>>> m = nn.Sequential(
...     nn.Conv2d(1, 32, 5, 1, 1),
...     nn.Flatten()
... )
from .module import Module
class Flatten(Module):
    r"""Flatten a contiguous range of dims of the input into a single dim.

    For use with :class:`~nn.Sequential`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).

    Examples::

        >>> m = nn.Sequential(
        ...     nn.Conv2d(1, 32, 5, 1, 1),
        ...     nn.Flatten()
        ... )
    """

    # Names of the attributes that are declared as constants of this module.
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Store the inclusive range of dims that ``forward`` will flatten."""
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        """Return ``input`` with dims ``start_dim..end_dim`` flattened."""
        # Delegates to Tensor.flatten with the range stored at construction.
        return input.flatten(self.start_dim, self.end_dim)
torch.Tensor.flatten()
用法和 torch.flatten() 一样：t.flatten(start_dim=0, end_dim=-1) 等价于 torch.flatten(t, start_dim=0, end_dim=-1)。