Semantic Segmentation Series 20: CCNet (PyTorch Implementation)

CCNet: "CCNet: Criss-Cross Attention for Semantic Segmentation"

This article covers CCNet, the principle behind Criss-Cross attention, and the implementation details. Reading time: about 10 minutes.

CCNet was published at ICCV 2019. Criss-Cross attention is an improvement over Non-Local, designed mainly to reduce the computational cost of the attention mechanism.


Introduction

Previous semantic segmentation work left a number of problems, large and small, chiefly the receptive-field problem and the computational-cost problem.

  • To address the receptive field, the DeepLab line of work focuses on dilated (atrous) convolutions, such as the ASPP module proposed in DeepLabv2. PSPNet proposes the PPM module to handle multiple scales, using several adaptive pooling operations to extract contextual information at different scales. DANet turns to the attention mechanism, using a Spatial Attention module to build spatial relations between any two positions. PSANet designs a bi-directional information flow over the attention maps so that every position is related to every other position.
  • The Non-Local mechanism, i.e. attention, can largely solve the receptive-field problem, but it is severely limited by its computational complexity. The simplest ways to cut the cost are to reduce the number of channels or lower the resolution, but both lose information and degrade performance. EMANet instead uses the EM algorithm to iteratively compute the attention maps, updating them through a small set of bases and thereby reducing the computation.

CCNet was proposed precisely to address this computational-complexity problem.

Model Details

Criss-cross attention

To address the computational cost of attention, the authors propose the Criss-Cross attention module.

Figure 1: Criss-Cross attention compared with Non-Local

Figure 1 sketches Non-Local (a) and Criss-Cross attention (b). Unlike Non-Local, which computes attention over the whole image at once, Criss-Cross attention computes, for each position, attention only along its row and column. Naturally, if this criss-cross attention were computed only once, positions outside the row and column would never be involved; the attention would be confined to the two axes and the captured semantic information would be limited. To solve this, the authors place two Criss-Cross attention modules in series: for a given position, the first pass aggregates the attention along its row and column, and feeding that output through a second Criss-Cross attention pass then lets the position interact, indirectly, with every position in the image. As shown in Figure 2, in Loop 1 the light-green cell already contains the content of the blue cell; in Loop 2 the dark-green cell attends to the light-green cell, which carries the light-green plus blue content, so the dark-green cell ends up related to both the light-green and the blue cells.

Figure 2: How attention information from two positions becomes related over two loops
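To make the claim that two criss-cross passes cover the whole image concrete, here is a small sketch of my own (not from the paper): it builds the one-step criss-cross connectivity matrix for a toy H×W grid and checks that a single pass only reaches positions in the same row or column, while two passes reach every position.

import torch

# Toy check (illustrative): positions i and j are connected in one criss-cross step
# iff they share a row or a column; two steps should connect every pair of positions.
H, W = 4, 5
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
ys, xs = ys.reshape(-1), xs.reshape(-1)          # flatten to N = H*W positions

# one_step[i, j] = 1 if position j lies on the same row or column as position i
one_step = ((ys[:, None] == ys[None, :]) | (xs[:, None] == xs[None, :])).float()
two_step = (one_step @ one_step) > 0             # reachable within two passes

print(one_step.bool().all().item())              # False: one pass misses most positions
print(two_step.all().item())                     # True: two passes reach every position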

Criss-Cross Attention Implementation Details

Figure 3: Criss-Cross attention implementation details

The implementation of Criss-Cross attention is still based on the attention mechanism. First, the backbone output X goes through a convolution that reduces the number of channels, producing a feature map H ∈ R^{C×W×H}. H then passes through three convolution blocks to produce Q, K, and V, where {Q, K} ∈ R^{C'×W×H} and C' is set to one eighth of C to reduce computation. Q and K are then combined by the Affinity operation to produce A.

The Affinity operation: for each position u in Q we take a vector Q_u along the channel axis; from K we extract the set of vectors lying in the same row and column as u, \Omega_u \in R^{(H+W-1)\times C'}, whose i-th element is \Omega_{i,u}. The Affinity is computed as:

d_{i,u} = Q_u \Omega_{i,u}^{T}

The resulting D, after a Softmax over the (H+W-1) dimension, gives A ∈ R^{(H+W-1)×W×H}.

For V ∈ R^{C×W×H}, at each position u we likewise collect the set of vectors in the same row and column along the channel axis, \Phi_u \in R^{(H+W-1)\times C}. Multiplying this set with A completes the Aggregation operation; adding the original input H then gives the output H':

H'_u = \sum_{i} A_{i,u} \Phi_{i,u} + H_u
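As a sanity check of the Affinity and Aggregation formulas above, here is a minimal single-position sketch with toy tensors (my own illustration, not the paper's implementation; for simplicity it drops the duplicate position from the row instead of using the -inf masking trick used in the real module below):

import torch

# Toy shapes (hypothetical): C = 6, C' = 2, H = 4, W = 5; one query position u = (uy, ux).
C, Cp, H, W = 6, 2, 4, 5
Q = torch.randn(Cp, H, W)
K = torch.randn(Cp, H, W)
V = torch.randn(C, H, W)
H_in = torch.randn(C, H, W)
uy, ux = 2, 3

def criss_cross_set(T, uy, ux):
    # Positions in the same column and row as u; the duplicate (uy, ux) is dropped
    # from the row so the set has exactly H + W - 1 elements.
    col = T[:, :, ux]                                           # [*, H]
    row = torch.cat([T[:, uy, :ux], T[:, uy, ux + 1:]], dim=1)  # [*, W - 1]
    return torch.cat([col, row], dim=1)                         # [*, H + W - 1]

Omega = criss_cross_set(K, uy, ux)     # [C', H+W-1]
Phi = criss_cross_set(V, uy, ux)       # [C,  H+W-1]

d = Q[:, uy, ux] @ Omega               # Affinity: d_{i,u} = Q_u · Omega_{i,u}, shape [H+W-1]
A = torch.softmax(d, dim=0)            # attention weights over the criss-cross set
H_out_u = Phi @ A + H_in[:, uy, ux]    # Aggregation plus residual, shape [C]
print(H_out_u.shape)                   # torch.Size([6])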

To let each position u relate to every other position, the authors run Criss-Cross attention twice: computing criss-cross attention once more on H' yields H''. We then have:

H''_u = [f(A, u, \theta) + 1] \cdot f(A', u, \theta) \cdot H_\theta

In terms of the coordinates in Figure 2, this is:

H''_u \leftarrow [f(A, u_x, \theta_y, \theta_x, \theta_y) \cdot f(A', u_x, u_y, u_x, \theta_y) + f(A, \theta_x, u_y, \theta_x, \theta_y) \cdot f(A', u_x, u_y, \theta_x, u_y)] \cdot H_\theta

Model Architecture

The overall architecture is simple. Like Non-Local, the Criss-Cross attention module is flexible and can be inserted anywhere, so, as in other works, the authors simply place it after the CNN backbone to process the feature maps and finish the segmentation with plain upsampling. As mentioned above, two Criss-Cross attention passes are stacked; the authors call this stacked module the Recurrent Criss-Cross Attention (RCCA) module.

Figure 4: Overall architecture of CCNet

Model Code

Backbone: ResNet50 (8x downsampling)

ResNet50 is used as the backbone, with 8x downsampling and dilated convolutions added to it.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion: int = 1  # BasicBlock does not expand channels (Bottleneck uses 4)
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):
        
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)

        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 keeps stride=1; only conv1 (and the downsample branch) change the resolution
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 2  # note: starts at 2 here (torchvision starts at 1)
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # strides here are 1/2/1 (instead of the usual 2/2/2), so the overall output
        # stride of the backbone is 8 rather than 32
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            # note: unlike torchvision, the stride is not reset to 1 here; downsampling is
            # controlled directly by the stride arguments passed in __init__
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        return x

    def forward(self, x):
        return self._forward_impl(x)

    @staticmethod
    def _resnet(block, layers, pretrained_path=None, **kwargs):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path), strict=False)
        return model

    @staticmethod
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path, **kwargs)

    @staticmethod
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path, **kwargs)
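A quick smoke test for the backbone (same style as the tests further below); with replace_stride_with_dilation=[1, 2, 4], as used by the CCNet class later, a 224×224 input should come out 8x downsampled:

if __name__ == "__main__":
    # Smoke test (illustrative): expect [2, 2048, 28, 28], i.e. 8x downsampling
    backbone = ResNet.resnet50(replace_stride_with_dilation=[1, 2, 4])
    x = torch.randn(2, 3, 224, 224)
    print(backbone(x).shape)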

Criss-Cross attention module

This part is the core of CCNet. For each computation step, the tensor shape is annotated in the comment on the line above it.

import torch
import torch.nn as nn
import torch.nn.functional as F

def INF(B, H, W):
    # Returns a [B*W, H, H] tensor with -inf on the diagonal. It is added to the
    # column-wise energy so that, after the softmax, each position attends to itself
    # only once (the self position already appears in the row-wise energy).
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B*W, 1, 1)

class CrissCrossAttention(nn.Module):
    def __init__(self, in_channels):
        super(CrissCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.channels = in_channels // 8
        self.ConvQuery = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvKey = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvValue = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1)

        self.SoftMax = nn.Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, _, h, w = x.size()

        # [b, c', h, w]
        query = self.ConvQuery(x)
        # [b, w, c', h] -> [b*w, c', h] -> [b*w, h, c']
        query_H = query.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h).permute(0, 2, 1)
        # [b, h, c', w] -> [b*h, c', w] -> [b*h, w, c']
        query_W = query.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w).permute(0, 2, 1)
        
        # [b, c', h, w]
        key = self.ConvKey(x)
        # [b, w, c', h] -> [b*w, c', h]
        key_H = key.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)
        # [b, h, c', w] -> [b*h, c', w]
        key_W = key.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)
        
        # [b, c, h, w]
        value = self.ConvValue(x)
        # [b, w, c, h] -> [b*w, c, h]
        value_H = value.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)
        # [b, h, c, w] -> [b*h, c, w]
        value_W = value.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)
        
        # [b*w, h, c']* [b*w, c', h] -> [b*w, h, h] -> [b, h, w, h]
        energy_H = (torch.bmm(query_H, key_H) + self.INF(b, h, w)).view(b, w, h, h).permute(0, 2, 1, 3)
        # [b*h, w, c']*[b*h, c', w] -> [b*h, w, w] -> [b, h, w, w]
        energy_W = torch.bmm(query_W, key_W).view(b, h, w, w)
        # concatenate along axis 3 and softmax over the h+w attention scores -> [b, h, w, h+w]
        concate = self.SoftMax(torch.cat([energy_H, energy_W], 3))
        # [b, h, w, h] -> [b, w, h, h] -> [b*w, h, h]
        attention_H = concate[:,:,:, 0:h].permute(0, 2, 1, 3).contiguous().view(b*w, h, h)
        # [b, h, w, w] -> [b*h, w, w]
        attention_W = concate[:,:,:, h:h+w].contiguous().view(b*h, w, w)
 
        # [b*w, c, h] x [b*w, h, h] -> [b*w, c, h] -> [b, c, h, w]
        out_H = torch.bmm(value_H, attention_H.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        # [b*h, c, w] x [b*h, w, w] -> [b*h, c, w] -> [b, c, h, w]
        out_W = torch.bmm(value_W, attention_W.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)

        return self.gamma*(out_H + out_W) + x

if __name__ == "__main__":
    model = CrissCrossAttention(512)
    x = torch.randn(2, 512, 28, 28)
    model.cuda()
    out = model(x.cuda())
    print(out.shape)
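Note the output design: gamma is a learnable scalar initialized to zero, and the module ends with a residual connection (self.gamma * (out_H + out_W) + x), so at the start of training the attention branch contributes nothing and the module behaves as an identity mapping; its contribution is blended in gradually as gamma is learned.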

RCCAModule

The RCCA module simply applies the Criss-Cross attention module several times: a single module is shared and run in a loop. For convenience, the upsampling and the classification head are also integrated here.

class RCCAModule(nn.Module):
    def __init__(self, recurrence = 2, in_channels = 2048, num_classes=33):
        super(RCCAModule, self).__init__()
        self.recurrence = recurrence
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.inter_channels = in_channels // 4
        self.conv_in = nn.Sequential(
            nn.Conv2d(self.in_channels, self.inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(self.inter_channels)
        )
        self.CCA = CrissCrossAttention(self.inter_channels)
        self.conv_out = nn.Sequential(
            nn.Conv2d(self.inter_channels, self.inter_channels, 3, padding=1, bias=False)
        )
        self.cls_seg = nn.Sequential(
            nn.Conv2d(self.in_channels+self.inter_channels, self.inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(self.inter_channels),
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
            nn.Conv2d(self.inter_channels, self.num_classes, 1)
        )

    def forward(self, x):
        # reduce channels from C to C'   2048->512
        output = self.conv_in(x)

        for i in range(self.recurrence):
            output = self.CCA(output)

        output = self.conv_out(output)
        output = self.cls_seg(torch.cat([x, output], 1))
        return output

if __name__ == "__main__":
    model = RCCAModule(in_channels=2048)
    x = torch.randn(2, 2048, 28, 28)
    model.cuda()
    out = model(x.cuda())
    print(out.shape)

CCNet

Finally, put the pieces together into the complete model.

class CCNet(nn.Module):
    def __init__(self, num_classes):
        super(CCNet, self).__init__()
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.decode_head = RCCAModule(recurrence=2, in_channels=2048, num_classes=num_classes)
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.decode_head(x)
        return x

if __name__ == "__main__":
    model = CCNet(num_classes=2)
    x = torch.randn(2, 3, 224, 224)
    model.cuda()
    out = model(x.cuda())
    print(out.shape)

Dataset: CamVid

We test on the CamVid dataset.

# imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
 
torch.manual_seed(17)
# custom CamVid dataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# dataset paths
DATA_DIR = r'database/camvid/camvid/' # adjust to your own path
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True,drop_last=True)
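As a quick sanity check of the data pipeline (assuming the directory layout above exists locally), one batch can be fetched and inspected:

# Optional sanity check: one batch of images and integer masks.
# Expected shapes: images [8, 3, 224, 224] and masks [8, 224, 224].
images, masks = next(iter(train_loader))
print(images.shape, masks.shape)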

Model Training

model = CCNet(num_classes=33).cuda()

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
# multi-class cross-entropy loss
lossf = nn.CrossEntropyLoss(ignore_index=255)
# SGD optimizer with a step learning-rate schedule
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)

# train for 100 epochs
epochs_num = 100
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} ---  train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")
        
        #--------- save training logs ---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/CCNet_camvid.xlsx")
        #---------------- save the model -------------------
        if np.mod(epoch+1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/CCNet_{epoch+1}.pth')
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num,scheduler)

Training Results

Figure 5: Training curves on CamVid (train loss, train accuracy, test accuracy)