论文阅读:Duplex Contextual Relation Network for Polyp Segmentation

结肠镜图像分割论文阅读

  • 论文总体架构
    • 摘要
    • 引言
    • 相关工作
    • 工作总结: 1、图像内上下文关系模块 2、图像外上下文关系模块 这两个模块也是即插即用的。
    • 模型结构
        • 先上图片
        • 内部上下文关系
      • 外部上下文关系(这个平生还是第一次见,值得重点观察)
    • 实验分析
    • 讨论
    • 厚着脸皮,要个点赞收藏,谢谢支持!!!

论文总体架构

论文名称:用于息肉分割的双重上下文关系网络(ISBI2022)
作者单位:北京邮电大学
作者名称:尹子衿等
代码地址: https://github.com/PRIS-CV/DCRNet/blob/master/lib/DCRNet.py

摘要

结肠镜检查中的息肉自动分割在结直肠癌(CRC)的早期诊断中起着关键作用。然而,息肉图像的多样性极大增加了准确分割的难度。现有的研究主要集中在学习单个图像中的上下文信息,但未能利用跨图像的息肉的同步视觉模式。本文从整个数据集的整体角度来探索上下文相关性,并提出了一个双工上下文关系网络(DCRNet)来捕获图像内和交叉图像之间上下文关系。基于上述两种相似性,每个输入区域的特征可以通过嵌入上下文区域来增强每个输入区域的特征。为了存储训练过程中先前图像嵌入的特征区域,设计了情景记忆并作为队列操作。我们在EndoScene、Kvasir-SEG和最近发布的大规模PICCOLO数据集上评估了所提出的方法。实验结果表明,我们提出的DCRNet在广泛使用的评价指标方面优于最先进的方法。

贡献
1、提出来嵌入上下文区域;
2、设计了情景记忆并作为队列操作;
3、提出了DCRNet;
4、模型在多个结肠癌数据集上的表现良好。

引言

结肠癌的诊断和治疗中,对于息肉的区域分析是非常关键的步骤,切除息肉是预防和治疗早期结肠直肠癌的直接手段。结肠镜图像能够清晰地展示出整个患者结肠部分的信息,但是对于息肉的定位分割依然存在着以下困难:1、息肉多饰多样;2、息肉和结肠粘膜之间的边界过于模糊。如图所示:
论文阅读:Duplex Contextual Relation Network for Polyp Segmentation_第1张图片
从图像中我们能够观察到,有的比较明显,像 a b,肿起来的部分就是,而d就很夸张,c很不明显,不仔细看根本看不着。


相关工作

在现有的工作中,这里简介:
1、多尺度提取特征的网络:ACSNet(MICCAI 2020),结合上下文信息和局部细节来应对息肉特征多样性的问题。
PraNet使用多尺度的特征聚合的方法,根据局部特征提取轮廓图并通过上采样依次细化分割图。
2、利用辅助信息来约束分割结果:SFANet(MICCAI 2019),利用区域边界约束,来选择特征聚合,提高分割精度。

重点: 这些工作,额,好像都是在单个图像上找特征分割,这样的话是不是涉及到一个隐性的病灶相似度,然后选取对应的分割参数??如果是这样的话,一个模型所做到的工作就是在对于明显的病灶的分割的基础上,对于不同类型的息肉图像进行相应的隐形分类,简单图像简单分,复杂图像及不明显的图像就特殊方法,很有道理!
所以本文就要提到一个机制,叫做情景记忆!

理论证明:(Content-based medical image retrieval of ct images of
liver lesions using manifold learning)已经证明了从其他图像中检索在放射学病变治疗过程中的意义。
相关成果:在度量学习中已经有用到。
所以,本文采用这种思想,从整个数据集的整体角度来探讨交叉图像和图像内的特征关联。

工作总结:
1、图像内上下文关系模块
2、图像外上下文关系模块
这两个模块也是即插即用的。

模型结构

先上图片

论文阅读:Duplex Contextual Relation Network for Polyp Segmentation_第2张图片
首先看到网络框架图,它由三部分组成,编码器、解码器、底部信息处理模块。
编码解码器本文用到的是基于ResNet34的UNet,这里不再赘述。直接看重头戏!

内部上下文关系

class PAM_Module(Module):
    """ Position attention module"""
    #Ref from SAGAN
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = Parameter(torch.zeros(1))

        self.softmax = Softmax(dim=-1)
    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X (HxW) X (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        return out

这一段代码,作者在里面写的备注还是非常详细的,这个东东的作用就是建立当前图像中所有像素点之间的关系,然后将这种关系与输入相乘,从而得到加权的效果!当然,残差结构一直是保留项目,嗯,就是这样的。

外部上下文关系(这个平生还是第一次见,值得重点观察)

class DCRNet(ResNet34Unet):
    def __init__(self,
                 bank_size=20,
                 num_classes=1,
                 num_channels=3,
                 is_deconv=False,
                 decoder_kernel_size=3,
                 pretrained=True,
                 feat_channels=512
                 ):
        super().__init__(num_classes=1,
                 num_channels=3,
                 is_deconv=False,
                 decoder_kernel_size=3,
                 pretrained=True)
        
        self.bank_size = bank_size
        self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long))  # memory bank pointer
        self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes))  # memory bank
        self.bank_full = False
        
        # =====Attentive Cross Image Interaction==== #
        self.feat_channels = feat_channels
        self.L = nn.Conv2d(feat_channels, num_classes, 1)
        self.X = conv2d(feat_channels, 512, 3)
        self.phi = conv1d(512, 256)
        self.psi = conv1d(512, 256)
        self.delta = conv1d(512, 256)
        self.rho = conv1d(256, 512)
        self.g = conv2d(512 + 512, 512, 1)
        # =========Dual Attention========== #
        self.sa_head = PAM_Module(feat_channels)
        #=========Attention Fusion=========#
        self.fusion = nn.Conv2d(feat_channels, feat_channels, 1)
    #==Initiate the pointer of bank buffer==#
    def init(self):
        self.bank_ptr[0] = 0
        self.bank_full = False
        
    @torch.no_grad() #这句很重要!!!!
    def update_bank(self, x):
        ptr = int(self.bank_ptr)
        batch_size = x.shape[0]
        vacancy = self.bank_size - ptr
        if batch_size >= vacancy:
            self.bank_full = True
        pos = min(batch_size, vacancy)
        self.bank[ptr:ptr+pos] = x[0:pos].clone()
        # update pointer
        ptr = (ptr + pos) % self.bank_size
        self.bank_ptr[0] = ptr
        
    def down(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)        
        return e4, e3, e2, e1
    
    def up(self, feat, e3, e2, e1, x):
        center = self.center(feat)
        d4 = self.decoder4(torch.cat([center, e3], 1))
        d3 = self.decoder3(torch.cat([d4, e2], 1))
        d2 = self.decoder2(torch.cat([d3, e1], 1))
        d1 = self.decoder1(torch.cat([d2, x], 1))
 
        f1 = self.finalconv1(d1)
        f2 = self.finalconv2(d2)
        f3 = self.finalconv3(d3)
        f4 = self.finalconv4(d4)
                
        f4 = F.interpolate(f4, scale_factor=8, mode='bilinear', align_corners=True)
        f3 = F.interpolate(f3, scale_factor=4, mode='bilinear', align_corners=True)
        f2 = F.interpolate(f2, scale_factor=2, mode='bilinear', align_corners=True)
        
        return f4, f3, f2, f1
   
    def region_representation(self, input):
        X = self.X(input)
        L = self.L(input)
        aux_out = L
        batch, n_class, height, width = L.shape
        l_flat = L.view(batch, n_class, -1)
        # M = B * N * HW
        M = torch.softmax(l_flat, -1)
        channel = X.shape[1]
        # X_flat = B * C * HW
        X_flat = X.view(batch, channel, -1)
        # f_k = B * C * N
        f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
        return aux_out, f_k, X_flat, X
    
    def attentive_interaction(self, bank, X_flat, X):
        batch, n_class, height, width = X.shape
        # query = S * C
        query = self.phi(bank).squeeze(dim=2)
        # key: = B * C * HW
        key = self.psi(X_flat)
        # logit = HW * S * B (cross image relation)
        logit = torch.matmul(query, key).transpose(0,2)
        # attn = HW * S * B
        attn = torch.softmax(logit, 2) ##softmax维度要正确
        
        # delta = S * C
        delta = self.delta(bank).squeeze(dim=2)
        # attn_sum = B * C * HW
        attn_sum = torch.matmul(attn.transpose(1,2), delta).transpose(1,2)
        # x_obj = B * C * H * W
        X_obj = self.rho(attn_sum).view(batch, -1, height, width)

        concat = torch.cat([X, X_obj], 1)
        out = self.g(concat)
        return out
            
    def forward(self, x, flag='train'):
        batch_size = x.shape[0]
        #=== Stem ===#
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x_ = self.firstmaxpool(x)
 
        #=== Encoder ===#
        e4, e3, e2, e1  = self.down(x_)        
        #=== Attentive Cross Image Interaction ===#
        aux_out, patch, feats_flat, feats = self.region_representation(e4)
        if flag == 'train':
            self.update_bank(patch)
            ptr = int(self.bank_ptr)
            if self.bank_full == True:
                feature_aug = self.attentive_interaction(self.bank, feats_flat, feats)
            else:
                feature_aug = self.attentive_interaction(self.bank[0:ptr], feats_flat, feats)
        elif flag == 'test':
            feature_aug = self.attentive_interaction(patch, feats_flat, feats)
        #=== Dual Attention ===#
        sa_feat = self.sa_head(e4)
        #=== Fusion ===#
        feats = sa_feat + feature_aug
        #=== Decoder ===#
        f4, f3, f2, f1 = self.up(feats, e3, e2, e1, x)
        aux_out = F.interpolate(aux_out, scale_factor=32, mode='bilinear', align_corners=True)
        return aux_out, f4, f3, f2, f1


实验分析

实验部分主要包含以下几个方面:

数据集名称 图像数量 train valid test
EndoScene 912 548 182 182
Kvasir-SEG 1000 600 200 200
PICCOLO 3433 2203 897 333

设备 学习率 epoches batchsize memory size
NVIDIA RTX 2080Ti 1e-4 150 4 20(Kvasir) / 40(E & P)

论文阅读:Duplex Contextual Relation Network for Polyp Segmentation_第3张图片
论文阅读:Duplex Contextual Relation Network for Polyp Segmentation_第4张图片
从可视化和表格数据上,我们能够看出本文模型的有效性!

论文阅读:Duplex Contextual Relation Network for Polyp Segmentation_第5张图片

对于这两个经典模型,有着不错的提高,说明了本模型的设计和内外上下文推理体系的合理性。

讨论

本文最大的亮点应该是外部memory 的设定,对于整个模型的体系架构,我们应当学习到这种内部隐性的分类思想和理念,所谓的外部上下文关系模块的机理也是如此!

厚着脸皮,要个点赞收藏,谢谢支持!!!

你可能感兴趣的:(笔记,计算机视觉,人工智能,深度学习)