(a) 正常卷积:输入特征 Input feature,kernel = 3,stride = 1,pad = 1,输出特征 Output feature。
(b) 空洞卷积:与图 (a) 不同之处在于 pad = 2,同时引入一个 rate = 2(表示卷积核中参数间间隔的超参)。
还可通过下图进一步理解普通卷积与空洞卷积的区别,其中 hole size 即为上图中的 rate。
正常卷积:蓝色为输入,绿色为输出,移动的阴影为卷积核(kernel = 2, stride = 1,pad = 0);
空洞卷积:蓝色为输入,绿色为输出,移动的阴影为卷积核(kernel = 3, stride = 1, pad = 0, rate = 1);
以上演示动图来源于此,对理解卷积操作有很大的帮助。在实际中,空洞卷积一般有两种实现方式:(1)卷积核填充 0;(2)输入等间隔采样。
一般认为图片中相邻的像素点存在信息冗余,故而空洞卷积具备以下两个优势:
(1) 扩大感受野:传统的下采样虽可增加感受野,但会降低空间分辨率。而使用空洞卷积能够在扩大感受野的同时,保证分辨率。这十分适用于检测、分割任务中,感受野的增大可检测、分割大的目标,高分辨率则可精确定位目标。
(2) 捕获多尺度上下文信息:空洞卷积中参数 dilation rate 表明在卷积核中填充 (dilation rate-1) 个 0。设置不同 dilation rate 给网络带来不同的感受野,即获取了多尺度信息。
空洞卷积得到的某一层的结果中,邻近的像素是从相互独立的子集中卷积得到的,相互之间缺少依赖,故而空洞卷积也存在不足:
(1) 局部信息丢失:由于空洞卷积的计算方式类似于棋盘格式,某一层得到的卷积结果,来自上一层的独立的集合,没有相互依赖,因此该层的卷积结果之间没有相关性,即局部信息丢失;
(2) 远距离获取的信息没有相关性:由于空洞卷积稀疏的采样输入信号,使得远距离卷积得到的信息之间没有相关性。
上图即为 ASPP 模块示意:对 Input Feature Map 以不同采样率的空洞卷积并行采样;然后将得到的结果 concat 到一起,扩大通道数;最后通过 1 × 1 1 \times 1 1×1 的卷积将通道数降低到预期的数值。相当于以多个比例捕捉图像的上下文。
上图为添加 ASPP 模块后的网络结构,将 Block3 的输出输入到 ASPP,经过多尺度的空洞卷积采样后经过池化操作,然后由 1 × 1 1 \times 1 1×1 卷积将通道数降低至预期值。
一个没有 BN 层的 PyTorch 实现的 ASPP 代码(DeepLabv3 的 ASPP 中加入了 BN 层)如下:
#without bn version
class ASPP(nn.Module):
def __init__(self, in_channel=512, depth=256):
super(ASPP,self).__init__()
self.mean = nn.AdaptiveAvgPool2d((1, 1)) #(1,1)means ouput_dim
self.conv = nn.Conv2d(in_channel, depth, 1, 1)
self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)
self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)
self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12)
self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)
self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)
def forward(self, x):
size = x.shape[2:]
image_features = self.mean(x)
image_features = self.conv(image_features)
image_features = F.upsample(image_features, size=size, mode='bilinear')
atrous_block1 = self.atrous_block1(x)
atrous_block6 = self.atrous_block6(x)
atrous_block12 = self.atrous_block12(x)
atrous_block18 = self.atrous_block18(x)
net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
atrous_block12, atrous_block18], dim=1))
return net
【参考】