ResNet通过前层与后层的“短路连接”(Shortcuts),加强了前后层之间的信息流通,在一定程度上缓解了梯度消失的现象,从而可以将神经网络搭建得很深。DenseNet最大化了这种前后层信息交流,通过建立前面所有层与后面层的密集连接,实现了特征在通道维度上的复用,使其可以再参数计算量最少的情况下实现比ResNet更优的性能
DenseNet的网络架构如下图所示,网络由多个Dense Block与中间的卷积池化组成,核心就在Dense Block中。Dense Block中的黑点代表一个卷积层,其中的多条黑色线代表数据的流动,每一层的输入由前面的所有卷积层输出组成。注意这里使用了通道的拼接(Concatnate)操作,而非ResNet的逐元素相加操作
DenseNet的结构有如下两个特效:
Block实现细节如下图所示,每个Block由若干个Bottleneck的卷积层组成,对应上图中的黑点。Bottleneck由BN、ReLU、1×1卷积、BN、ReLU、3×3卷积的顺序构成
关于Block,有以下4个细节需要注意:
DenseNet网络的优势主要体现在以下两个方面:
DenseNet的不足在于由于需要进行多次Concatnate操作,数据需要被复制多次,显存容易增加得很快,需要一定的显存优化技术。另外DesneNet是一种更为特殊的网络,ResNet则相对一般化一些,因此ResNet的应用更为广泛
利用PyTorch实现DenseNet的一个Block:
import torch
from torch import nn
import torch.nn.functional as F
#实现一个Bottleneck的类,初始化需要输入通道数与GrowthRate这两个参数
class Bottleneck(nn.Module):
def __init__(self, nChannels, growthRate):
super(Bottleneck, self).__init__()
#通常1×1卷积的通道数为GrowRate的4倍
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
#将输入x同计算的结果out进行通道拼接
out = torch.cat((x, out), 1)
return out
class Denseblock(nn.Module):
def __init__(self, nChannels, growthRate, nDenseBlocks):
super(Denseblock, self).__init__()
layers = []
#将每一个Bottleneck利用nn.Sequential()整合起来,输入通道数需要线性增长
for i in range(int(nDenseBlocks)):
layers.append(Bottleneck(nChannels, growthRate))
nChannels += growthRate
self.denseblock = nn.Sequential(*layers)
def forward(self, x):
return self.denseblock(x)
终端:
>>> import torch
>>> from densenet_block import Denseblock
>>> #实例化DenseBlock,包含了6个Bottleneck
>>> denseblock = Denseblock(64, 32, 6)
>>> #查看denseblock的网络结构,由6个Bottleneck组成
>>> denseblock
Denseblock(
(denseblock): Sequential(
#第1个Bottleneck的输入通道数为64,输出固定为32
(0): Bottleneck(
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
#第2个Bottleneck的输入通道数为96,输出固定为32
(1): Bottleneck(
(bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
#第3个Bottleneck的输入通道数为128,输出固定为32
(2): Bottleneck(
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
#第4个Bottleneck的输入通道数为160,输出固定为32
(3): Bottleneck(
(bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
#第5个Bottleneck的输入通道数为192,输出固定为32
(4): Bottleneck(
(bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
#第6个Bottleneck的输入通道数为224,输出固定为32
(5): Bottleneck(
(bn1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
)
>>> input = torch.randn(1, 64, 256, 256)
>>> output = denseblock(input) #将输入传入denseblock结构中
>>> #输出通道数为224 + 32 = 64 + 32 × 6 = 256
>>> output.shape
torch.Size([1, 256, 256, 256])