torch.flatten()的用法

torch.flatten()的用法

语法:

torch.flatten(input, start_dim=0, end_dim=-1)

参数:
input:输入的tensor,即将要被 “弄平” 的tensor
start_dim:“弄平” 的起始纬度,默认值为0
end_dim: “弄平” 的终止维度,默认值为-1,-1指的是最后一个纬度

作用:
从起始纬度到终止纬度,将输入的tensor弄平

示例:

1、生成一个size为(2,3,2)的全1tensor

>>> import torch
>>> a = torch.ones(2,3,2)  
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],
        [[1., 1.],
         [1., 1.],
         [1., 1.]]])

2、start_dim,end_dim使用默认值,就是将a整个弄平

>>> torch.flatten(a) 
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

3、同2,对 0,1,2纬度进行弄平,将a整个弄平

>>> torch.flatten(a, 0, 2) 
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

4、对 0,1纬度进行弄平,size变为(2*3, 2)=(6,2)

>>> torch.flatten(a, 0, 1)  
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])

5、对 1,2纬度进行弄平,size变为(2, 3*2)=(2, 6)

torch.flatten(a, 1,2)  
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

>>> torch.flatten(a,1,-1)  # 同上,对 1,2纬度进行弄平,size变为(2, 3*2)=(2, 6)
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
        
# start_dim=1,没有设置end_dim,使用默认值end_dim=-1, 
# 故同上,对 1,2纬度进行弄平,size变为(2, 3*2)=(2, 6)
>>> torch.flatten(a,1)  
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

你可能感兴趣的:(pytorch深度学习,python,深度学习,人工智能)