Pytorch阅读文档之flatten函数

pytorch中flatten函数

Pytorch阅读文档之flatten函数_第1张图片

torch.flatten()

#展平一个连续范围的维度,输出类型为Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
# Parameters:input (Tensor) – 输入为Tensor
#start_dim (int) – 展平的开始维度
#end_dim (int) – 展平的最后维度
#example
#一个3x2x2的三维张量
>>> t = torch.tensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                       [7, 8]],
                  [[9, 10],
                       [11, 12]]])
#当开始维度为0,最后维度为-1,展开为一维
>>> torch.flatten(t)
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
#当开始维度为0,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
>>> torch.flatten(t, start_dim=1)
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
>>> torch.flatten(t, start_dim=1).size()
torch.Size([3, 4])
#下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
#前面的就会合并
>>> torch.flatten(t, start_dim=0, end_dim=1)
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
>>> torch.flatten(t, start_dim=0, end_dim=1).size()
torch.Size([6, 2])

torch.nn.Flatten()

Class torch.nn.Flatten(start_dim=1, end_dim=-1)
#Flattens a contiguous range of dims into a tensor. 
#For use with Sequential. :
#param start_dim: first dim to flatten (default = 1). 
#param end_dim: last dim to flatten (default = -1).
#能力有限,个人认为是用于卷积中的
#Shape:
#Input: (N, *dims)(N,∗dims)
#Output: (N, \prod *dims)(N,∏∗dims) (for the default case).
#官方example
>>> m = nn.Sequential(
>>>     nn.Conv2d(1, 32, 5, 1, 1),
>>>     nn.Flatten()
>>> )
#源代码为 TORCH.NN.MODULES.FLATTEN
from .module import Module

[docs]class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).


    Examples::
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
    """
    __constants__ = ['start_dim', 'end_dim']

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)

torch.Tensor.flatten()

和torch.flatten()一样

你可能感兴趣的:(Pytorch,深度学习框架,人工智能,pytorch,flatten)