Pytorch中将tensor拉平的几种方法

文章目录

  • 前言
  • 一、卷积神经网络提取特征的流程
  • 二、几种常见方法
    • 1.view():元素总数不变改变形状
    • 2.flatten():将指定维度合并为一个维度
    • 3.squeeze():去掉维度数为1的维度


前言

当我们在搭建网络时,tensor进入全连接层/GAP/GMP/分类器之前需要对tensor进行拉平操作,保留某个维度或者去除某个维度,本文试着总结一下常见的将tensor拉平的方法,如有问题希望大家批评指正。


一、卷积神经网络提取特征的流程

在计算机视觉领域,无论是图像分类还是目标检测,CNN常被用作图片特征提取的Backbone(主干网络)。CNN经过某些卷积操作生成feature map,降低分辨率,增大通道数。在进入最后的全连接层/分类器之前时,特征信息最多,往往此时需要保留通道数而忽略图片的宽高。本文以上一篇文章的MobileNetV2为例,阐述几种tensor拉平的方法。

二、几种常见方法

1.view():元素总数不变改变形状

MobileNetV2的forward部分:

 def forward(self, x):
        #2,3,32,32
        x = self.conv1(x)
        #2,3,32,32
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        #2,320,4,4
        x = self.conv2(x)
        #2,1280,4,4
        x = self.avgpool(x)
        #2,1280,1,1
        #tensor拉平发生的位置
        x = x.view(x.size[0],-1)
        #2,1280
        x = self.linear(x)
        #2,1000
        return x

Pytorch中tensor的输入格式为[B,C,H,W],分别代表batch_size,channels,高,宽。简单回顾一下,假设输入的tensor为[2,3,32,32],经过forward到全连接层之前的tensor变为[2,1280,1,1],分辨率降低,通道数变多,我们的目的是将tensor拉平,即只要batch_size和channels,方便后续分类。此时可以采用view()操作,这也是最常见的操作。

'''
view()是根据元素总数来改变tensor形状的,即变形后的tensor元素总数不变
本例中元素总数为2*1280*1*1
x.size[0]是x的第一个维度batch_size,本例中为2-1代表自动计算该维度
想要去掉H,W则只需指定第一个维度的batch_size自动计算第二个维度即可,因为H,W经过卷积后均为1
'''

x = x.view(x.size[0],-1)

2.flatten():将指定维度合并为一个维度

forward中可以修改如下:

 def forward(self, x):
        #2,3,32,32
        x = self.conv1(x)
        #2,3,32,32
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        #2,320,4,4
        x = self.conv2(x)
        #2,1280,4,4
        x = self.avgpool(x)
        #2,1280,1,1
        #tensor拉平发生的位置
        #flatten的两种方式
        #将第一维之后的维度合并
        #x = torch.flatten(x,1)
        #x = x.flatten(1)
        #2,1280
        x = self.linear(x)
        #2,1000
        return x

flatten原型如下:

flatten(input,start_dim=0,end_dim=-1)

其中input为输入的tensor,start_dim为起始维度,end_dim为终止维度。flatten的功能为将start_dim到end_dim的维度合并为一个维度。本例中将128011合并为一个维度1280

3.squeeze():去掉维度数为1的维度

由于本例的特殊性,最后的H,W均为1,则可直接用squeeze()去掉维度为1的维度

 def forward(self, x):
        #2,3,32,32
        x = self.conv1(x)
        #2,3,32,32
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        #2,320,4,4
        x = self.conv2(x)
        #2,1280,4,4
        x = self.avgpool(x)
        #2,1280,1,1
        #tensor拉平发生的位置
        x = x.squeeze()
        #2,1280
        x = self.linear(x)
        #2,1000
        return x

未完待续……

你可能感兴趣的:(小技巧汇总专栏,pytorch,神经网络,深度学习,计算机视觉,卷积)