PyTorch: the ASPP module and atrous (dilated) convolution

import torch.nn as nn
import torch.nn.functional as F
import torch
"""ASPP模块"""


class ASPPCV(nn.Sequential):
    """
    ASPP卷积模块的定义
    """
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPCV, self).__init__(*modules)
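
# A quick sanity check (hypothetical sizes, not from the original post): with
# padding equal to dilation, the dilated 3x3 conv keeps the spatial resolution.
# branch = ASPPCV(256, 256, dilation=6)
# print(branch(torch.randn(2, 256, 64, 64)).shape)  # torch.Size([2, 256, 64, 64])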


class ASPPPooling(nn.Sequential):
    """
    ASPP的pooling层
    """
    def __init__(self, in_channels, out_channels):  # [in_channel=out_channel=256]
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),  # pools to [256, 1, 1]
            # adaptive average pooling only needs the desired output size (the argument in parentheses)
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
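
# A quick sanity check (hypothetical sizes): the pooling branch squeezes the input
# to 1x1 internally and interpolates back, so the output matches the input size.
# Batch size 2 is used because BatchNorm needs more than one value per channel in training mode.
# pool = ASPPPooling(256, 256)
# print(pool(torch.randn(2, 256, 64, 64)).shape)  # torch.Size([2, 256, 64, 64])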


class ASPP(nn.Module):
    """
    ASPP (Atrous Spatial Pyramid Pooling) block built from parallel dilated convolutions.
    """
    def __init__(self, in_channels, atrous_rates):  # atrous_rates=(6, 12, 18)
        super(ASPP, self).__init__()
        out_channels = in_channels
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),  # spatial size unchanged: (64-1+2*0)/1+1 = 64, i.e. [256, 64, 64]
            nn.BatchNorm2d(out_channels),  # [256, 64, 64]
            nn.ReLU()))  # 1x1 convolution branch
        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPCV(in_channels, out_channels, rate1))  # 3x3 conv (padding=6, dilation=6)
        modules.append(ASPPCV(in_channels, out_channels, rate2))  # 3x3 conv (padding=12, dilation=12)
        modules.append(ASPPCV(in_channels, out_channels, rate3))  # 3x3 conv (padding=18, dilation=18), output [256, 64, 64]
        modules.append(ASPPPooling(in_channels, out_channels))  # image-level pooling: pooled to (1, 1) ([256, 1, 1]), then upsampled back in forward
        self.convs = nn.ModuleList(modules)
        self.project = nn.Sequential(  # feature fusion: the concatenated input has 5x the original channels; project back to the original count
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),  # [1280, 64, 64] -> [256, 64, 64]
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        net = self.project(res)  # fuse the concatenated features: 1280 -> 256 channels
        # return x + net  (optional residual variant, left disabled)
        return net

e.g. ASPP(256, (6, 12, 18)): 256 is both the number of input channels and the number of output channels. ASPP enlarges the network's receptive field with only a small increase in parameters and does not change the number of feature channels; (6, 12, 18) are the dilation rates.
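
A minimal usage sketch under the assumptions in the comments above (a 256-channel, 64x64 backbone feature map; the batch size of 2 is illustrative):

feat = torch.randn(2, 256, 64, 64)  # hypothetical backbone feature map
aspp = ASPP(256, (6, 12, 18))
out = aspp(feat)
print(out.shape)  # torch.Size([2, 256, 64, 64]): resolution and channel count are unchanged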

Receptive-field calculation: for a dilated convolution with dilation rate dr and kernel size k, the equivalent kernel size is (dr-1)(k-1)+k. Padding does not enlarge the receptive field; setting padding = dilation (as ASPPCV does for its 3x3 kernels) simply keeps the output feature map the same spatial size as the input.

e.g. a 3x3 kernel with rate=2 and padding=2 --> equivalent kernel size = (2-1)(3-1)+3 = 5, with the output resolution unchanged.
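
As a quick check of the formula, a small helper (hypothetical, not part of the original code) prints the equivalent kernel size for the rates used in the ASPP example:

def equivalent_kernel_size(k, dr):
    # equivalent kernel size of a single dilated convolution: (dr - 1) * (k - 1) + k
    return (dr - 1) * (k - 1) + k

for dr in (1, 6, 12, 18):
    print(dr, equivalent_kernel_size(3, dr))  # -> 3, 13, 25, 37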
