当我们在搭建网络时,tensor进入全连接层/GAP/GMP/分类器之前需要对tensor进行拉平操作,保留某个维度或者去除某个维度,本文试着总结一下常见的将tensor拉平的方法,如有问题希望大家批评指正。
在计算机视觉领域,无论是图像分类还是目标检测,CNN常被用作图片特征提取的Backbone(主干网络)。CNN经过某些卷积操作生成feature map,降低分辨率,增大通道数。在进入最后的全连接层/分类器之前时,特征信息最多,往往此时需要保留通道数而忽略图片的宽高。本文以上一篇文章的MobileNetV2为例,阐述几种tensor拉平的方法。
MobileNetV2的forward部分:
def forward(self, x):
#2,3,32,32
x = self.conv1(x)
#2,3,32,32
x = self.bottleneck1(x)
x = self.bottleneck2(x)
x = self.bottleneck3(x)
x = self.bottleneck4(x)
x = self.bottleneck5(x)
x = self.bottleneck6(x)
x = self.bottleneck7(x)
#2,320,4,4
x = self.conv2(x)
#2,1280,4,4
x = self.avgpool(x)
#2,1280,1,1
#tensor拉平发生的位置
x = x.view(x.size[0],-1)
#2,1280
x = self.linear(x)
#2,1000
return x
Pytorch中tensor的输入格式为[B,C,H,W],分别代表batch_size,channels,高,宽。简单回顾一下,假设输入的tensor为[2,3,32,32],经过forward到全连接层之前的tensor变为[2,1280,1,1],分辨率降低,通道数变多,我们的目的是将tensor拉平,即只要batch_size和channels,方便后续分类。此时可以采用view()操作,这也是最常见的操作。
'''
view()是根据元素总数来改变tensor形状的,即变形后的tensor元素总数不变
本例中元素总数为2*1280*1*1
x.size[0]是x的第一个维度batch_size,本例中为2,-1代表自动计算该维度
想要去掉H,W则只需指定第一个维度的batch_size自动计算第二个维度即可,因为H,W经过卷积后均为1
'''
x = x.view(x.size[0],-1)
forward中可以修改如下:
def forward(self, x):
#2,3,32,32
x = self.conv1(x)
#2,3,32,32
x = self.bottleneck1(x)
x = self.bottleneck2(x)
x = self.bottleneck3(x)
x = self.bottleneck4(x)
x = self.bottleneck5(x)
x = self.bottleneck6(x)
x = self.bottleneck7(x)
#2,320,4,4
x = self.conv2(x)
#2,1280,4,4
x = self.avgpool(x)
#2,1280,1,1
#tensor拉平发生的位置
#flatten的两种方式
#将第一维之后的维度合并
#x = torch.flatten(x,1)
#x = x.flatten(1)
#2,1280
x = self.linear(x)
#2,1000
return x
flatten原型如下:
flatten(input,start_dim=0,end_dim=-1)
其中input为输入的tensor,start_dim为起始维度,end_dim为终止维度。flatten的功能为将start_dim到end_dim的维度合并为一个维度。本例中将128011合并为一个维度1280
由于本例的特殊性,最后的H,W均为1,则可直接用squeeze()去掉维度为1的维度
def forward(self, x):
#2,3,32,32
x = self.conv1(x)
#2,3,32,32
x = self.bottleneck1(x)
x = self.bottleneck2(x)
x = self.bottleneck3(x)
x = self.bottleneck4(x)
x = self.bottleneck5(x)
x = self.bottleneck6(x)
x = self.bottleneck7(x)
#2,320,4,4
x = self.conv2(x)
#2,1280,4,4
x = self.avgpool(x)
#2,1280,1,1
#tensor拉平发生的位置
x = x.squeeze()
#2,1280
x = self.linear(x)
#2,1000
return x
未完待续……