Semantic Segmentation Series 21: ANNNet (PyTorch Implementation)

ANNNet (Asymmetric Non-local Neural Network) comes from the paper "Asymmetric Non-Local Neural Networks for Semantic Segmentation",
published at ICCV 2019.

Introduction

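Two dark clouds used to hang over semantic segmentation:

  • Insufficient receptive field: several studies show that modeling long-range dependencies effectively improves segmentation quality. Many earlier works address this, for example: stacking multiple convolution layers to enlarge the receptive field; ASPP in DeepLabv2, which enlarges the receptive field with dilated (atrous) convolutions; RepLKNet, which uses very large kernels; PSPNet, whose Pyramid Pooling Module (PPM) gathers context at multiple scales; and Non-Local Net, DANet and similar models, which build long-range dependencies through Non-local operations, i.e. attention.
  • Excessive computation: large kernels are usually very expensive. RepLKNet addressed this cost in 2022 by optimizing the convolution computation, but in 2019 it was still an open problem, so very-large-kernel designs were not yet taken seriously. Stacking convolution layers suffers from degradation, and the dilated convolutions in ASPP produce gridding artifacts. Non-local operations can relate any two spatial positions and are excellent at enlarging the receptive field, but they are so expensive that the model runs slowly. CCNet, from the same year, also reduces Non-local cost, using its Criss-Cross attention module. Reducing the cost of Non-local operations was clearly a mainstream topic in 2019.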

To reduce the cost of Non-local as well, this paper proposes the Asymmetric Non-local Block, shown in the figure below.

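Unlike the original Non-local operation, the Asymmetric Non-local Block samples when computing Key and Value, which shrinks both and greatly reduces the cost of the Matmul and Softmax operations. These two operations are exactly the time-consuming parts of Non-local, so the sampling is targeted rather than arbitrary. The sampling method is described in detail below.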

Figure 1. The two Non-local structures and a comparison of their computational cost

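Highlights of the paper:

  1. Sampling is implemented with a Pyramid Pooling Module (PPM), which reduces the cost of Non-local.
  2. The AFNB and APNB modules are proposed to reduce computational overhead and fuse features, improving segmentation accuracy.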

Model and Paper Details

Overall Model


Figure 2. Structure of the Asymmetric Non-local Neural Network

The ANN structure is shown in Figure 2. The backbone is mainly ResNet, and the highlights of the network are the AFNB and APNB structures.

AFNB(Asymmetric Fusion Non-local Block)


Figure 3. Structure of the Asymmetric Fusion Non-local Block (AFNB)

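AFNB differs from the Non-local block in that Key and Value are sampled with a Pyramid Pooling Module (PPM) (Figure 3, right). PPM was first proposed in PSPNet. Here the authors pool the Key and Value feature maps with pool sizes [1, 3, 6, 8], giving outputs of 1×1, 3×3, 6×6 and 8×8; flattened and concatenated, they contain exactly 1 + 9 + 36 + 64 = 110 positions.
For example, take a feature map of shape [8, 256, 56, 56], i.e. [batch, channels, h, w]. Without sampling, Key and Value have shape [8, 256, 56*56]; with sampling, they have shape [8, 256, 110]. The reduction factor is $\frac{56 \times 56}{110} \approx 28.5$. The total cost does not necessarily shrink linearly, but Query has to be multiplied with Key and then with Value, and for these matmul operations the saving is substantial.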
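The arithmetic in this example can be checked with a few lines of Python. This is only a rough sketch: the helper names are made up for illustration, and the multiply-add count covers just the two matrix products of the attention step, not the 1×1 convolutions.

```python
# Illustrative cost comparison; the function names are hypothetical.

def sampled_positions(pool_sizes):
    """Number of key/value positions left after pyramid pooling."""
    return sum(s * s for s in pool_sizes)


def attention_macs(num_queries, num_keys, channels):
    """Multiply-adds for (query @ key) plus (attention @ value)."""
    return 2 * num_queries * num_keys * channels


h, w, channels = 56, 56, 256
n = h * w                               # 3136 query positions
s = sampled_positions([1, 3, 6, 8])     # 1 + 9 + 36 + 64

full = attention_macs(n, n, channels)   # standard Non-local: N x N similarity
asym = attention_macs(n, s, channels)   # asymmetric version: N x S similarity

print(s)                      # 110
print(round(n / s, 1))        # 28.5
print(round(full / asym, 1))  # 28.5
```

Under this simple cost model the saving in the two matmuls equals the ratio N / S, which is why a larger input resolution benefits even more from sampling.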
For the AFNB module, let $X_l$ be the output of backbone stage 4 and $X_h$ the output of stage 5. Query, Key, Value and the output Out are computed as follows (Key and Value come from $X_l$, matching the code):
$$query = f_q(X_h)$$
$$key = \Phi_{sample}(f_k(X_l))$$
$$value = \Phi_{sample}(f_v(X_l))$$
$$Out = f_{out}(SoftMax(query \odot key) \odot value)$$

The corresponding code is:

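import torch
import torch.nn as nn
import torch.nn.functional as F
# PPMModule is defined in the "Pyramid Pooling Module" part of the reproduction section.

class AFNPBlock(nn.Module):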
    def __init__(self, in_channels, key_channels, value_channels, pool_sizes=[1,3,6,8]):
        super(AFNPBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        # query takes Xh from stage5, so in_channels here is 2048
        self.Conv_query = nn.Sequential(
            nn.Conv2d(self.in_channels, self.key_channels, 1),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU()
        )   
        # key and value take Xl, the stage4 output; in_channels//2 here is 1024
        self.Conv_key = nn.Sequential(
            nn.Conv2d(self.in_channels // 2, self.key_channels, 1),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU()
        )
        self.Conv_value = nn.Conv2d(self.in_channels // 2, self.value_channels, 1)

        self.ConvOut = nn.Conv2d(self.value_channels, self.out_channels, 1)
        self.ppm = PPMModule(pool_sizes)
        # Initialize ConvOut to zero
        nn.init.constant_(self.ConvOut.weight, 0)
        nn.init.constant_(self.ConvOut.bias, 0)

    def forward(self, low_feats, high_feats):
        # low_feats = stage4   high_feats = stage5
        b, c, h, w = high_feats.size()

        # value = [batch, 110, value_channels]; 110 = 1 + 3*3 + 6*6 + 8*8 is determined by pool_sizes
        value = self.ppm(self.Conv_value(low_feats)).permute(0, 2, 1)
        # key = [batch, key_channels, 110]; 110 is determined by pool_sizes as above
        key = self.ppm(self.Conv_key(low_feats))
        # query = [batch, key_channels, h*w] -> [batch, h*w, key_channels]
        query = self.Conv_query(high_feats).view(b, self.key_channels, -1).permute(0, 2, 1)

        # Concat_QK = [batch, h*w, 110]
        Concat_QK = torch.matmul(query, key)
        Concat_QK = (self.key_channels ** -.5) * Concat_QK
        Concat_QK = F.softmax(Concat_QK, dim=-1)

        # Aggregate_QKV = [batch, h*w, Value_channels]
        Aggregate_QKV = torch.matmul(Concat_QK, value)
        # Aggregate_QKV = [batch, value_channels, h*w]
        Aggregate_QKV = Aggregate_QKV.permute(0, 2, 1).contiguous()
        # Aggregate_QKV = [batch, value_channels, h*w] -> [batch, value_channels, h, w]
        Aggregate_QKV = Aggregate_QKV.view(b, self.value_channels, *high_feats.size()[2:])
        # Conv out
        Aggregate_QKV = self.ConvOut(Aggregate_QKV)

        return Aggregate_QKV


if __name__ == "__main__":
    low_feat = torch.randn((2, 1024, 64, 64))
    highfeat = torch.randn((2, 2048, 64, 64))
    AFNB = AFNPBlock(in_channels=2048, value_channels=256, key_channels=256)
    out = AFNB(low_feat, highfeat)
    print("AFNP output.shape:",out.shape)

APNB(Asymmetric Pyramid Non-local Block)


Figure 4. Structure of the Asymmetric Pyramid Non-local Block (APNB)

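APNB is also similar to Non-local, except that the figure omits the step that computes Key. As in AFNB, Value is computed with a convolution followed by Pyramid Pooling for sampling. Note that the convolutions for Query and Key share weights, so the Query and Key computed at first are identical; Key is then fed into Pyramid Pooling for sampling. The corresponding formulas are: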
$$query = f_q(Y_F)$$
$$key = \Phi_{sample}(f_q(Y_F))$$
$$value = \Phi_{sample}(f_v(Y_F))$$
$$Out = f_{out}(SoftMax(query \odot key) \odot value)$$
In code:

class APNBBlock(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels, value_channels, pool_sizes=[1, 3, 6, 8]):
        super(APNBBlock, self).__init__()

        # Generally speaking, here, in_channels==out_channels and key_channels==value_channels
        self.in_channels = in_channels
        self.out_channles = out_channels
        self.value_channels = value_channels
        self.key_channels = key_channels
        self.pool_sizes = pool_sizes

        self.Conv_Key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU()
        )
        # Conv_Query and Conv_Key share weights, so the query and key features are identical before sampling
        self.Conv_Query = self.Conv_Key
        
        self.Conv_Value = nn.Conv2d(self.in_channels, self.value_channels, 1)
        self.Conv_Out = nn.Conv2d(self.value_channels, self.out_channles, 1)
        nn.init.constant_(self.Conv_Out.weight, 0)
        nn.init.constant_(self.Conv_Out.bias, 0)
        self.ppm = PPMModule(pool_sizes=self.pool_sizes)

    def forward(self, x):
        b, _, h, w = x.size()
        
        # value = [batch, value_channels, 110] -> [batch, 110, value_channels]
        value = self.ppm(self.Conv_Value(x)).permute(0, 2, 1)
        # query = [batch, key_channels, -1 -> h*w] -> [batch, h*w, key_channels]
        query = self.Conv_Query(x).view(b, self.key_channels, -1).permute(0, 2, 1)
        # key = [batch, key_channels, 110]  where 110 = sum(s**2 for s in pool_sizes) = 1 + 3**2 + 6**2 + 8**2
        key = self.ppm(self.Conv_Key(x))

        # Concat_QK = [batch, h*w, 110]
        Concat_QK = torch.matmul(query, key)
        Concat_QK = (self.key_channels ** -.5) * Concat_QK
        Concat_QK = F.softmax(Concat_QK, dim=-1)

        # Aggregate_QKV = [batch, h*w, Value_channels]
        Aggregate_QKV = torch.matmul(Concat_QK, value)
        # Aggregate_QKV = [batch, value_channels, h*w]
        Aggregate_QKV = Aggregate_QKV.permute(0, 2, 1).contiguous()
        # Aggregate_QKV = [batch, value_channels, h*w] -> [batch, value_channels, h, w]
        Aggregate_QKV = Aggregate_QKV.view(b, self.value_channels, *x.size()[2:])
        # Conv out
        Aggregate_QKV = self.Conv_Out(Aggregate_QKV)

        return Aggregate_QKV

if __name__ == "__main__":
    x = torch.randn((2, 512, 64, 64))
    APNB = APNBBlock(in_channels=512, out_channels=512, value_channels=256, key_channels=256)
    out = APNB(x)
    print("APNB output.shape:",out.shape)
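
The weight sharing between Conv_Query and Conv_Key can be checked on a toy module. The sketch below is not part of the model: `SharedQK` and its channel sizes are hypothetical and only mirror the `self.Conv_Query = self.Conv_Key` assignment.

```python
import torch
import torch.nn as nn


class SharedQK(nn.Module):
    """Toy module (hypothetical name) mirroring `self.Conv_Query = self.Conv_Key`."""

    def __init__(self, in_channels=8, key_channels=4):
        super().__init__()
        self.conv_key = nn.Sequential(
            nn.Conv2d(in_channels, key_channels, 1),
            nn.BatchNorm2d(key_channels),
            nn.ReLU(),
        )
        # Assigning the same module object registers it under a second name;
        # no new parameters are created.
        self.conv_query = self.conv_key


torch.manual_seed(0)
block = SharedQK().eval()   # eval() so BatchNorm uses its fixed running statistics
x = torch.randn(1, 8, 5, 5)

same_object = block.conv_query is block.conv_key
same_output = torch.equal(block.conv_query(x), block.conv_key(x))
# parameters() de-duplicates shared tensors: conv weight, conv bias, BN weight, BN bias
num_param_tensors = len(list(block.parameters()))
print(same_object, same_output, num_param_tensors)
```

Because both names point at one module, the optimizer sees a single set of weights, and the asymmetry between Query and Key comes only from the pyramid pooling applied to Key.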

Summary

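ANNNet samples the Key and Value of Non-local with Pyramid Pooling to reduce computation. It proposes two asymmetric Non-local structures, AFNB and APNB, used for feature fusion and for improving segmentation accuracy, respectively.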

Reproducing the Model

Backbone: ResNet50 (8x downsampling)

Note that in this ResNet50 the last two stages do not downsample, so their feature maps have the same spatial size.
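
As a quick check of the 8x figure, the spatial size can be traced through the strided layers with the standard convolution output-size formula. This is a sketch under the assumption of a 224×224 input; `conv_out_size` is a hypothetical helper, and the kernel, stride and padding values are the ones used in the ResNet code below.

```python
def conv_out_size(n, kernel, stride, padding, dilation=1):
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


size = 224
size = conv_out_size(size, kernel=7, stride=2, padding=3)   # conv1: 224 -> 112
size = conv_out_size(size, kernel=3, stride=2, padding=1)   # maxpool: 112 -> 56
# layer1 uses stride 1 and keeps 56
size = conv_out_size(size, kernel=3, stride=2, padding=1)   # layer2: 56 -> 28
# layer3 and layer4 use stride 1 with padding equal to the dilation, so the size is kept
stage4 = conv_out_size(size, kernel=3, stride=1, padding=2, dilation=2)
stage5 = conv_out_size(stage4, kernel=3, stride=1, padding=4, dilation=4)

print(stage4, stage5, 224 // stage5)   # 28 28 8
```

The last two stages therefore produce feature maps of the same size, which is what lets AFNB reshape its output with the height and width of the stage-5 features.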

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    # BasicBlock does not widen its output channels, so its expansion is 1
    expansion: int = 1
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):
        
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes ,kernel_size=3, stride=stride, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 keeps the resolution; only conv1 (and downsample) apply the stride
        self.conv2 = nn.Conv2d(planes, planes ,kernel_size=3, stride=1, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 2
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
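        if dilate:
            # Unlike torchvision (which sets stride = 1 here), the stride is left unchanged;
            # layer3 and layer4 are already built with stride=1.
            self.dilation *= stride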
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)

        x = self.layer3(x)
        out.append(x)
        x = self.layer4(x)
        out.append(x)
        return out
        
    def forward(self, x) :
        return self._forward_impl(x)

    @staticmethod
    def _resnet(block, layers, pretrained_path=None, **kwargs):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path), strict=False)
        return model

    @staticmethod
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path, **kwargs)

    @staticmethod
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path, **kwargs)

ANNNet

Pyramid Pooling Module

First, implement the sampling step.

import torch
import torch.nn as nn
import torch.nn.functional as F
class PPMModule(nn.ModuleList):
    def __init__(self, pool_sizes=[1,3,6,8]):
        super(PPMModule, self).__init__()
        for pool_size in pool_sizes:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_size)
                )
            )
    def forward(self, x):
        out = []
        b, c, _, _ = x.size()
        for index, module in enumerate(self):
            out.append(module(x))
        # Flatten each pooled output and concatenate them along the last dimension
        return torch.cat([output.view(b, c, -1) for output in out], -1) 

if __name__ == "__main__":
    input = torch.randn((2, 256, 32, 32))
    ppmModule = PPMModule()
    out = ppmModule(input)
    print(out.size())
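
To see what the adaptive average pooling inside PPMModule does to one axis, here is a pure-Python sketch. It assumes the usual adaptive-pooling binning rule (start = floor(i·H/s), end = ceil((i+1)·H/s)); the function names are hypothetical and this is not PyTorch's actual implementation.

```python
def adaptive_bins(in_size, out_size):
    """(start, end) index pairs of each output bin along one axis (assumed rule)."""
    bins = []
    for i in range(out_size):
        start = (i * in_size) // out_size                         # floor
        end = ((i + 1) * in_size + out_size - 1) // out_size      # ceil
        bins.append((start, end))
    return bins


def adaptive_avg_pool_1d(values, out_size):
    """Average each bin; a 1D stand-in for one axis of adaptive average pooling."""
    return [sum(values[a:b]) / (b - a) for a, b in adaptive_bins(len(values), out_size)]


print(adaptive_bins(32, 3))                                # [(0, 11), (10, 22), (21, 32)]
print(adaptive_avg_pool_1d([1, 2, 3, 4, 5, 6, 7, 8], 3))   # [2.0, 4.5, 7.0]
```

The bins may overlap when the input size is not divisible by the output size, but the number of outputs depends only on the pool size, so the concatenated length stays 110 for any input resolution.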

AFNB

class AFNPBlock(nn.Module):
    def __init__(self, in_channels, key_channels, value_channels, pool_sizes=[1,3,6,8]):
        super(AFNPBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        # query takes Xh from stage5, so in_channels here is 2048
        self.Conv_query = nn.Sequential(
            nn.Conv2d(self.in_channels, self.key_channels, 1),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU()
        )   
        # key and value take Xl, the stage4 output; in_channels//2 here is 1024
        self.Conv_key = nn.Sequential(
            nn.Conv2d(self.in_channels // 2, self.key_channels, 1),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU()
        )
        self.Conv_value = nn.Conv2d(self.in_channels // 2, self.value_channels, 1)

        self.ConvOut = nn.Conv2d(self.value_channels, self.out_channels, 1)
        self.ppm = PPMModule(pool_sizes)
        # Initialize ConvOut to zero
        nn.init.constant_(self.ConvOut.weight, 0)
        nn.init.constant_(self.ConvOut.bias, 0)

    def forward(self, low_feats, high_feats):
        # low_feats = stage4   high_feats = stage5
        b, c, h, w = high_feats.size()

        # value = [batch, 110, value_channels]; 110 = 1 + 3*3 + 6*6 + 8*8 is determined by pool_sizes
        value = self.ppm(self.Conv_value(low_feats)).permute(0, 2, 1)
        # key = [batch, key_channels, 110]; 110 is determined by pool_sizes as above
        key = self.ppm(self.Conv_key(low_feats))
        # query = [batch, key_channels, h*w] -> [batch, h*w, key_channels]
        query = self.Conv_query(high_feats).view(b, self.key_channels, -1).permute(0, 2, 1)

        # Concat_QK = [batch, h*w, 110]
        Concat_QK = torch.matmul(query, key)
        Concat_QK = (self.key_channels ** -.5) * Concat_QK
        Concat_QK = F.softmax(Concat_QK, dim=-1)

        # Aggregate_QKV = [batch, h*w, Value_channels]
        Aggregate_QKV = torch.matmul(Concat_QK, value)
        # Aggregate_QKV = [batch, value_channels, h*w]
        Aggregate_QKV = Aggregate_QKV.permute(0, 2, 1).contiguous()
        # Aggregate_QKV = [batch, value_channels, h*w] -> [batch, value_channels, h, w]
        Aggregate_QKV = Aggregate_QKV.view(b, self.value_channels, *high_feats.size()[2:])
        # Conv out
        Aggregate_QKV = self.ConvOut(Aggregate_QKV)

        return Aggregate_QKV


if __name__ == "__main__":
    low_feat = torch.randn((2, 1024, 64, 64))
    highfeat = torch.randn((2, 2048, 64, 64))
    AFNB = AFNPBlock(in_channels=2048, value_channels=256, key_channels=256)
    out = AFNB(low_feat, highfeat)
    print("AFNP output.shape:",out.shape)

APNB

class APNBBlock(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels, value_channels, pool_sizes=[1, 3, 6, 8]):
        super(APNBBlock, self).__init__()

        # Generally speaking, here, in_channels==out_channels and key_channels==value_channels
        self.in_channels = in_channels
        self.out_channles = out_channels
        self.value_channels = value_channels
        self.key_channels = key_channels
        self.pool_sizes = pool_sizes

        self.Conv_Key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels),
            nn.ReLU()
        )
        # Conv_Query and Conv_Key share weights, so the query and key features are identical before sampling
        self.Conv_Query = self.Conv_Key
        
        self.Conv_Value = nn.Conv2d(self.in_channels, self.value_channels, 1)
        self.Conv_Out = nn.Conv2d(self.value_channels, self.out_channles, 1)
        nn.init.constant_(self.Conv_Out.weight, 0)
        nn.init.constant_(self.Conv_Out.bias, 0)
        self.ppm = PPMModule(pool_sizes=self.pool_sizes)

    def forward(self, x):
        b, _, h, w = x.size()
        
        # value = [batch, value_channels, 110] -> [batch, 110, value_channels]
        value = self.ppm(self.Conv_Value(x)).permute(0, 2, 1)
        # query = [batch, key_channels, -1 -> h*w] -> [batch, h*w, key_channels]
        query = self.Conv_Query(x).view(b, self.key_channels, -1).permute(0, 2, 1)
        # key = [batch, key_channels, 110]  where 110 = sum(s**2 for s in pool_sizes) = 1 + 3**2 + 6**2 + 8**2
        key = self.ppm(self.Conv_Key(x))

        # Concat_QK = [batch, h*w, 110]
        Concat_QK = torch.matmul(query, key)
        Concat_QK = (self.key_channels ** -.5) * Concat_QK
        Concat_QK = F.softmax(Concat_QK, dim=-1)

        # Aggregate_QKV = [batch, h*w, Value_channels]
        Aggregate_QKV = torch.matmul(Concat_QK, value)
        # Aggregate_QKV = [batch, value_channels, h*w]
        Aggregate_QKV = Aggregate_QKV.permute(0, 2, 1).contiguous()
        # Aggregate_QKV = [batch, value_channels, h*w] -> [batch, value_channels, h, w]
        Aggregate_QKV = Aggregate_QKV.view(b, self.value_channels, *x.size()[2:])
        # Conv out
        Aggregate_QKV = self.Conv_Out(Aggregate_QKV)

        return Aggregate_QKV

if __name__ == "__main__":
    x = torch.randn((2, 512, 64, 64))
    APNB = APNBBlock(in_channels=512, out_channels=512, value_channels=256, key_channels=256)
    out = APNB(x)
    print("APNB output.shape:",out.shape)

ANNNet

import torch
import torch.nn as nn
import torch.nn.functional as F

class asymmetric_non_local_network(nn.Module):
    def __init__(self, num_classes=2, aux_loss=False):
        super(asymmetric_non_local_network, self).__init__()
        self.num_classes = num_classes

        # Whether to return the auxiliary loss branch as well
        self.aux_loss = aux_loss
        
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        
        # AFNB and APNB
        self.fusion = AFNPBlock(in_channels=2048, value_channels=256, key_channels=256, pool_sizes=[1,3,6,8])

        self.APNB = APNBBlock(in_channels=512, out_channels=512, value_channels=256, key_channels=256, pool_sizes=[1,3,6,8])
        # extra added layers
        self.context = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            self.APNB
        )
        self.cls = nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.05),
            nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )

       
    def forward(self, x_):
        x = self.backbone(x_)
        aux_x = self.dsn(x[-2])
        x = self.fusion(x[-2], x[-1])
        x = self.context(x)
        x = self.cls(x)
        aux_x = F.interpolate(aux_x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
        x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
        if self.aux_loss:
            return aux_x, x
        return x

if __name__ == "__main__":
    x = torch.randn((2, 3, 224, 224))
    # aux_loss=True so that forward returns the pair (aux_x, x) indexed below
    ANNNet = asymmetric_non_local_network(num_classes=2, aux_loss=True)
    out = ANNNet(x)
    print("ANNNet auxoutput.shape:",out[0].shape)
    print("ANNNet output.shape:",out[1].shape)
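
When the network is built with aux_loss=True, forward returns the pair (aux_x, x), and the two outputs are normally combined into one training loss. The sketch below shows one way to do that with random stand-in tensors. The weight 0.4 is an assumption chosen for illustration; the article does not state a value, and the training script further down does not use the auxiliary branch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 2
# Stand-ins for the (aux_x, x) pair returned when aux_loss=True
aux_logits = torch.randn(2, num_classes, 8, 8)
main_logits = torch.randn(2, num_classes, 8, 8)
target = torch.randint(0, num_classes, (2, 8, 8))

AUX_WEIGHT = 0.4  # assumed value, not taken from this article

main_loss = F.cross_entropy(main_logits, target, ignore_index=255)
aux_loss = F.cross_entropy(aux_logits, target, ignore_index=255)
total_loss = main_loss + AUX_WEIGHT * aux_loss
print(total_loss.item())
```

The auxiliary head only supervises the stage-4 features during training; at inference time only the main output is needed.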

Train

Dataset: CamVid

# Import libraries
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
 
torch.manual_seed(17)
# Custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder

    The transformation pipeline (resize, flips, normalization, tensor conversion)
    is fixed in __init__.
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# Set the dataset paths
DATA_DIR = r'database/camvid/camvid/' # adjust to your own path
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True,drop_last=True)

Start training

model = asymmetric_non_local_network(num_classes=33).cuda()

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
# Use multi-class cross-entropy as the loss function
lossf = nn.CrossEntropyLoss(ignore_index=255)
# Train with the SGD optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)

# Train for 100 epochs
epochs_num = 100
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} ---  train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")
        
        #---------Save training statistics---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/ANNNet_camvid.xlsx")
        #----------------Save the model-------------------
        if np.mod(epoch+1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/ANNNet_{epoch+1}.pth')
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num,scheduler)
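
With step_size=50 and gamma=0.5, the schedule above halves the learning rate once during the 100 epochs. The closed-form sketch below reproduces the values; `step_lr` is a hypothetical helper written for illustration, not the PyTorch scheduler itself, and it assumes scheduler.step() is called once at the end of every epoch as in the loop above.

```python
def step_lr(base_lr, epoch, step_size=50, gamma=0.5):
    """Learning rate in effect during `epoch` (0-based) under step decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 49, 50, 99):
    print(epoch, step_lr(0.1, epoch))
```

Epochs 0 to 49 run at 0.1 and epochs 50 to 99 run at 0.05.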

Training Results

[Figure: ANNNet training results on CamVid]
