ASPP(Atrous Spatial Pyramid Pooling)

Inspired by SPP, the semantic-segmentation model DeepLabv2 introduced the ASPP module, which applies several parallel atrous (dilated) convolution layers with different sampling rates. The features extracted at each rate are processed in separate branches and then fused to produce the final result. By using different dilation rates, the module builds kernels with different receptive fields and thereby captures multi-scale object information. The structure is shown in the figure below:
[Figure 1: Structure of the ASPP module]
ASPP is built from atrous (dilated) convolutions. Two goals are in tension: we want the extracted features to have a large receptive field, but we also want the feature-map resolution not to drop too much (heavy resolution loss discards fine detail around object boundaries). A large receptive field normally requires either a large kernel or a large stride when pooling; the former is computationally expensive, and the latter loses resolution. Dilated convolution resolves this conflict: it enlarges the receptive field while sacrificing little resolution. The idea is illustrated below:
[Figure 2: Dilated convolutions with rates 1, 2 and 4]
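The key property described above can be checked directly: with `padding = dilation`, a 3×3 dilated convolution enlarges the receptive field but leaves the spatial size unchanged. A minimal PyTorch sketch (the channel and image sizes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)

# Ordinary 3x3 conv: 3x3 receptive field per layer.
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)

# Dilated 3x3 conv (rate=2): effective kernel 5x5, yet with
# padding=dilation the output spatial size stays 32x32.
dilated = nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2)

print(conv(x).shape)     # torch.Size([1, 64, 32, 32])
print(dilated(x).shape)  # torch.Size([1, 64, 32, 32])
```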
(a) A dilated convolution with rate = 1: the receptive field of the kernel is 3×3, i.e. an ordinary convolution.
(b) A dilated convolution with rate = 2, stacked on top of (a): each output element has a 7×7 receptive field.
(c) A dilated convolution with rate = 4, stacked on top of (b): each output element has a 15×15 receptive field.

Computing the receptive field of a dilated convolution
The receptive field of a dilated convolution falls into two cases:

(1) Ordinary dilated convolution: with dilation rate d and kernel size k,

    receptive field = (d − 1) × (k − 1) + k

(2) Padded dilated convolution: with dilation rate d and kernel size k,

    receptive field = 2 × (d − 1) × (k − 1) + k
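Formula (1) can be sketched in code. The helper names below are my own; the second function stacks single-layer receptive fields (stride 1) and reproduces the 3 / 7 / 15 values from the figure:

```python
def effective_kernel(k, d):
    # Effective kernel size of one dilated conv:
    # (d - 1) * (k - 1) + k, equivalently d * (k - 1) + 1.
    return (d - 1) * (k - 1) + k

def stacked_rf(rates, k=3):
    # Receptive field after stacking stride-1 dilated convs with the given rates.
    rf = 1
    for d in rates:
        rf += effective_kernel(k, d) - 1
    return rf

print(effective_kernel(3, 6))  # 13: the rate=6 branch in ASPP
print(stacked_rf([1]))         # 3  -- figure (a)
print(stacked_rf([1, 2]))      # 7  -- figure (b)
print(stacked_rf([1, 2, 4]))   # 15 -- figure (c)
```

Note that a single rate=2 layer only has a 5×5 receptive field; the 7×7 and 15×15 values in the figure arise from stacking.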
ASPP code (PyTorch):

import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPP(nn.Module):
    def __init__(self, in_channel=512, depth=256):
        super(ASPP, self).__init__()
        # Image-level branch: global average pooling followed by a 1x1 conv.
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channel, depth, 1, 1)
        # 1x1 conv branch (k=1, s=1, no padding).
        self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)
        # 3x3 dilated conv branches; padding=dilation keeps the spatial size.
        self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)

        # Fuse the five concatenated branches back down to `depth` channels.
        self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)

    def forward(self, x):
        size = x.shape[2:]

        # Image-level features: pool to 1x1, project, then upsample back.
        image_features = self.mean(x)
        image_features = self.conv(image_features)
        image_features = F.interpolate(image_features, size=size,
                                       mode='bilinear', align_corners=False)

        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)

        # Concatenate all five branches along the channel axis and fuse.
        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
                                              atrous_block12, atrous_block18], dim=1))
        return net
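A quick sanity check of the branch shapes (a standalone sketch, assuming the same in_channel=512, depth=256 configuration as above): every parallel branch outputs `depth` channels at the input resolution, so concatenating them multiplies the channel count. Three of the five branches are shown here for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 512, 32, 32)
depth = 256

# 1x1 conv branch and one dilated branch: both keep the 32x32 resolution.
branch1 = nn.Conv2d(512, depth, 1)(x)
branch6 = nn.Conv2d(512, depth, 3, padding=6, dilation=6)(x)

# Image-level branch: pool to 1x1, project, upsample back to 32x32.
pooled = F.interpolate(nn.Conv2d(512, depth, 1)(nn.AdaptiveAvgPool2d(1)(x)),
                       size=x.shape[2:], mode='bilinear', align_corners=False)

cat = torch.cat([pooled, branch1, branch6], dim=1)
print(cat.shape)  # torch.Size([2, 768, 32, 32]) -- 3 of the 5 branches
```

With all five branches the concatenation has `depth * 5` channels, which is exactly the input channel count of `conv_1x1_output`.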

