densenet网络结构_CVPR2017最佳论文DenseNet(二)pytorch实现

982549697255e6bdc49d2ca7af0233e3.png

根据上文(CVPR2017最佳论文DenseNet(一)原理分析)进行理论介绍,本文使用pytorch进行了网络结构的复现。下文从包含网络块与整体网络结构两个部分,网络超参数均按照论文中所述。在进行代码实现之前,首先回顾一下将要用到的符号与网络结构图:

(a)k表示增长率,也就是growth rate,代表每个bottleneck层的输出通道数,固定不变

(b)Hl(·)表示第l层的映射函数(即卷积等操作);

e4b160c040d202b697eb12a1efe9a56e.png

1、Bottleneck Layers 瓶颈层

瓶颈层主要作用是使用1x1的卷积降低特征图通道数,这样使得后面3x3卷积的计算量的到减轻。根据论文中所述,1x1卷积的输出通道数为4k,也就是4倍的增长率。计算的顺序为BN + ReLU + Conv1x1 + BN + ReLU + Conv3x3,对应的代码实现如下:

class 

上述代码种最重要的一句是“out = torch.cat((x, out), dim=1)”,这行代码是在前向卷积计算以后调用的,也就是说当前第l层(Bottleneck层)的forward的返回值是映射函数H(·)输出的k个特征图与前面所有的特征图(共l*k个)拼接,这是个递归的过程!对应原理图中本层与后面层(后面Bottleneck层)相连的线,保证每一层计算都不会影响到上一层的特征图,使得前层特征图可以不断的以累积拼接的形式向后层传递。另外,每个Bottleneck层的输出通道数都是相同的,均为增长率k

2、Transition Layers 转换层

转换层为各个Dense Block之间的部分,也就是论文中对应的Pooling layer。计算过程为BN+ ReLU + Conv1x1 + Avg Pooling2x2,代码如下:

class 

3、制作DenseBlock

我在DenseBlock内部全部使用Bottleneck层,目的是尽可能用1x1卷积降低通道数以保证3x3卷积计算量不会太大。代码如下:

def 

创建Dense Block仅仅是一个循环过程,需要注意的是“4”,每添加一个Bottleneck层都会使得下一层的输入通道数增加k个,因为每一层输入都是前面所有层的输出。

4、FirstConv首个卷积层

这不是DenseNet提出的概念,而是为了代码清晰所以单独作为一个类来实现。FirstConv负责将输入图片从3个通道变为和自己想要的m个通道,从而输入到后面的DenseBlock层,代码如下:

class 

5、DenseNet实现

前一篇理论分析中的网络结构表格如下所示,复现中使用DenseNet-121(k=32)。对应4个Dense Block包含的Bottleneck层数分别是6、12、24和16。

densenet网络结构_CVPR2017最佳论文DenseNet(二)pytorch实现_第1张图片

DenseNet实现包含下面三个部分:

(1)首先需要使用一个卷积+池化降低特征图长宽并将通道数设定到某个值,作为后面Dense Block的输入;

(2)随后是4个Dense Block,每个Dense Block由多个Bottleneck层组成。前3个Dense Block后面都紧跟一个Transition层,且Transition层输出通道数为其输入通道数的0.5倍(即compression=0.5),(3)最后一个Dense Block后面没有Transition层,而是7x7全局均值池化层,池化层后面是用于分类的全连接层。

网络的实现中有两点需要特殊注意:第一是增长率不能太大,否则会爆显存;第二是必须预先计算好在超参数设定为具体值时对应各个Dense Block的输出通道数,并进行调试确定输出通道数是否正确。对应的代码实现如下:

class 

6、输出测试

测试上述定义的DenseNet类是否正确,可以通过检查其输出是否为预期的尺寸。比如对于b

x = torch.randn(size=(4, 3, 224, 224))
densenet = DenseNet(channels_in=3, compression=0.5, growth_rate=12, num_classes=10,num_bottleneck=[6, 12, 24, 16],
                        num_channels_before_dense=32,
                        num_dense_block=4)
out = densenet(x)

参考:

https://github.com/gpleiss/efficient_densenet_pytorch

你可能感兴趣的:(densenet网络结构,densenet论文)