【IPMI 2023】Rethinking Boundary Detection in Deep Learning Models for Medical Image Segmentation


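Commentary (Chinese): IPMI 2023, new work from Hao Chen's team at HKUST | CTO: Rethinking the Role of Boundary Detection in Medical Image Segmentation (qq.com)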

Paper: https://arxiv.org/abs/2305.00678

Code: https://github.com/xiaofang007/CTO

Introduction

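This paper proposes a new network architecture, CTO (Convolution, Transformer, and Operator), which combines convolutional neural networks, vision Transformers, and an explicit boundary detection operator to achieve high-accuracy image segmentation while keeping a favorable balance between accuracy and efficiency.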

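CTO follows the standard encoder-decoder segmentation paradigm. The encoder uses a popular CNN backbone to capture local semantic information, together with a lightweight ViT auxiliary network that models long-range dependencies. To strengthen boundary learning, the paper further proposes a boundary-guided decoder: boundary masks produced by a dedicated boundary detection operator serve as explicit supervision that guides the decoding process.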

Convolution, Transformer, and Operator (CTO)

[Image 1: overall architecture of CTO]

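CTO follows the encoder-decoder paradigm and uses skip connections to aggregate low-level features from the encoder into the decoder. The encoder consists of a mainstream CNN and an auxiliary ViT; the decoder uses a boundary detection operator to guide its learning.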

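  • A dual-stream encoder that combines a CNN and a lightweight vision Transformer to capture, respectively, local feature dependencies within the image and long-range dependencies between image patches.
  • An operator-guided decoder that uses a boundary detection operator (e.g., Sobel) to guide learning through the generated boundary mask. The whole model is trained end-to-end.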

Dual-Stream Encoder

CTO first builds a convolutional stream, using Res2Net as the backbone to capture local feature dependencies.
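The comment on the self.resnet call in CTO.forward lists the four Res2Net stage outputs for a batch of sixteen 256×256 images. As a small worked example, the sketch below reproduces those shapes, assuming the usual ResNet-style strides of 4, 8, 16 and 32, and derives the scale_factor values (16, 8, 4) that CTO.forward uses to bring its predictions back to full resolution. STAGES, stage_shapes and upsample_factor are illustrative names, not part of the repository.

```python
# (name, channels, stride) per backbone stage. Channels follow the comment in
# CTO.forward; the strides are the usual ResNet-50-style values (assumption).
STAGES = [("x1", 256, 4), ("x2", 512, 8), ("x3", 1024, 16), ("x4", 2048, 32)]


def stage_shapes(batch, height, width):
    """Return {stage name: (B, C, H, W)} for an input of size height x width."""
    return {
        name: (batch, channels, height // stride, width // stride)
        for name, channels, stride in STAGES
    }


def upsample_factor(feature_size, input_size):
    """Scale factor that brings a feature map back to the input resolution."""
    return input_size // feature_size


shapes = stage_shapes(batch=16, height=256, width=256)
for name, shape in shapes.items():
    print(name, shape, "upsample x", upsample_factor(shape[2], 256))
```

The first printed line is `x1 (16, 256, 64, 64) upsample x 4`, matching the x1 shape in the code comment and the scale_factor=4 applied to o1.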

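CTO uses an auxiliary stream based on a lightweight Vision Transformer to capture long-range dependencies between image patches. Concretely, it consists of several parallel lightweight Transformer blocks that take feature patches of different scales as input. All Transformer blocks share a similar structure: a patch embedding layer followed by a Transformer encoding layer.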

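In LightViT, the patch embedding layer converts the input feature patches into embedding vectors, turning the spatial dimensions into a sequence dimension. The Transformer encoding layer then applies self-attention across the patches to capture long-range dependencies between them. Through this self-attention, LightViT models the interactions between feature patches and thereby extracts global context from the image.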
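A minimal sketch of this patch-level self-attention, modeled on MultiHeadedAttention in CTO_net.py: each non-overlapping p×p patch of a (B, C, H, W) map becomes one token of length C·p·p, tokens attend to each other with scaled dot-product attention, and the result is folded back to the original layout. To stay short, it assumes identity query/key/value projections (the repository uses 1×1 convolutions) and omits the output convolution and feed-forward layer; the function names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def patches_to_tokens(x, p):
    """(B, C, H, W) -> (B, N, C*p*p), one token per non-overlapping p x p patch."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // p, p, w // p, p)     # split H and W into (grid, patch)
    x = x.permute(0, 2, 4, 1, 3, 5).contiguous()  # (B, H/p, W/p, C, p, p)
    return x.view(b, (h // p) * (w // p), c * p * p)


def tokens_to_patches(t, c, h, w, p):
    """Inverse of patches_to_tokens: (B, N, C*p*p) -> (B, C, H, W)."""
    b = t.shape[0]
    t = t.view(b, h // p, w // p, c, p, p)
    t = t.permute(0, 3, 1, 4, 2, 5).contiguous()  # (B, C, H/p, p, W/p, p)
    return t.view(b, c, h, w)


def patch_self_attention(q, k, v, p):
    """Scaled dot-product attention in which every p x p patch is one token."""
    _, c, h, w = q.shape
    qt, kt, vt = (patches_to_tokens(t, p) for t in (q, k, v))
    scores = qt @ kt.transpose(-2, -1) / math.sqrt(qt.size(-1))
    out = F.softmax(scores, dim=-1) @ vt
    return tokens_to_patches(out, c, h, w, p)


# Multi-scale use: one channel chunk per patch size, as in PatchTrans.
x = torch.randn(2, 64, 64, 64)
chunks = torch.chunk(x, 4, dim=1)
outs = [patch_self_attention(ch, ch, ch, p) for ch, p in zip(chunks, [32, 16, 8, 4])]
y = torch.cat(outs, dim=1)
print(y.shape)  # torch.Size([2, 64, 64, 64])
```

Large patches give few, long tokens (coarse global context), small patches give many short tokens (finer interactions); splitting channels across patch sizes keeps the cost of each scale modest.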

Boundary-Guided Decoder

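The boundary-guided decoder uses a gradient operator module to extract boundary information of foreground objects. A boundary fusion module then integrates the boundary-enhanced features with the multi-level encoder features, aiming to characterize both intra-class and inter-class consistency in feature space and thus enrich the feature representation. This lets the decoder exploit boundary information when producing its predictions, which yields more accurate segmentation.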

Boundary Enhanced Module (BEM)

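The boundary enhancement module takes high-level and low-level features as input, extracts boundary information, and filters out boundary-irrelevant information. The Sobel operator is applied in the horizontal (Gx) and vertical (Gy) directions to obtain gradient maps. Concretely, the paper uses two 3×3 convolutions with fixed (non-learnable) parameters, applied with stride 1. The two kernels are defined as: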

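$$
G_x = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}, \qquad
G_y = \begin{bmatrix} 1 & 2 & 1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix}
$$

(Values as in filter_x and filter_y of get_sobel in the released code.)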
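To see what the fixed kernels respond to, the sketch below applies the filter_x / filter_y values from get_sobel to a synthetic vertical step edge with torch.nn.functional.conv2d at stride 1. It assumes replicate padding via F.pad so the demo has no artificial edges at the image border; the repository's Conv2d uses zero padding (padding=1), which does produce border responses.

```python
import torch
import torch.nn.functional as F

# Fixed 3x3 Sobel kernels, same values as filter_x / filter_y in get_sobel.
SOBEL_X = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])
SOBEL_Y = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]])


def sobel_gradients(img):
    """img: (B, 1, H, W). Returns (Mx, My), same spatial size, stride 1."""
    padded = F.pad(img, (1, 1, 1, 1), mode="replicate")  # assumption: replicate border
    mx = F.conv2d(padded, SOBEL_X.view(1, 1, 3, 3), stride=1)
    my = F.conv2d(padded, SOBEL_Y.view(1, 1, 3, 3), stride=1)
    return mx, my


# A vertical step edge: columns 0-2 are 0, columns 3-5 are 1.
img = torch.zeros(1, 1, 6, 6)
img[..., 3:] = 1.0
mx, my = sobel_gradients(img)
print(mx[0, 0])
print(my[0, 0])
```

Every row of Mx is `[0, 0, -4, -4, 0, 0]`: only the two columns straddling the step respond. My is zero everywhere because the image does not change vertically.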

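These two convolutions are then applied to the input feature map F to obtain the gradient maps Mx and My. The gradient maps are normalized with a sigmoid function and fused with the input feature map to obtain the edge-enhanced feature map Fe:

$$F_e = \sigma(M_{xy}) \odot F$$

where $\odot$ denotes element-wise multiplication, $\sigma$ denotes the sigmoid function, and $M_{xy}$ is the concatenation of Mx and My along the channel dimension. The edge-enhanced feature maps are then fused with a simple stack of convolutional layers. Finally, the output feature map is supervised by the ground-truth boundary map, which suppresses edge responses inside objects and yields boundary-enhanced features.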
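The paper's formulation gates the input with σ applied to Mx and My concatenated along the channel dimension. run_sobel in the released code does something slightly different: it gates with the sigmoid of the gradient magnitude sqrt(g_x² + g_y²), a single channel that is broadcast over all input channels. The sketch below follows run_sobel; make_fixed_sobel and boundary_gate are hypothetical names, and the BatchNorm2d that get_sobel appends after each convolution is omitted.

```python
import torch
import torch.nn as nn


def make_fixed_sobel(in_chan, out_chan=1):
    """Frozen Sobel convs that sum over all input channels (like get_sobel, without BN)."""
    kx = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])
    ky = kx.t()  # [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
    convs = []
    for k in (kx, ky):
        conv = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
        weight = k.view(1, 1, 3, 3).repeat(out_chan, in_chan, 1, 1)
        conv.weight = nn.Parameter(weight, requires_grad=False)  # not learnable
        convs.append(conv)
    return convs


def boundary_gate(conv_x, conv_y, feat):
    """Fe = sigmoid(|G|) * F with |G| = sqrt(Mx^2 + My^2), broadcast over channels."""
    mx, my = conv_x(feat), conv_y(feat)
    magnitude = torch.sqrt(mx ** 2 + my ** 2)
    return torch.sigmoid(magnitude) * feat


conv_x, conv_y = make_fixed_sobel(in_chan=8)
feat = torch.randn(2, 8, 16, 16)
fe = boundary_gate(conv_x, conv_y, feat)
print(fe.shape)  # torch.Size([2, 8, 16, 16])
```

Because the magnitude is non-negative, the gate lies in [0.5, 1): flat regions are scaled by 0.5 while strong edges pass almost unchanged, so the module emphasizes edges without zeroing out the rest of the feature map.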

Boundary Inject Module (BIM)

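The boundary-enhanced features from BEM serve as prior knowledge that improves the representation of the encoder features. BIM introduces a dual-path boundary fusion scheme that strengthens the representation of both foreground and background features. Concretely, BIM takes two inputs: the channel-wise concatenation of the boundary-enhanced features with the corresponding encoder features, and the features from the previous decoder layer. Both inputs are fed into BIM, which contains two independent paths that promote the foreground and background feature representations, respectively.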

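  • Foreground path: the two inputs are concatenated along the channel dimension and passed through a series of Conv-BN-ReLU (convolution, batch normalization, ReLU) layers to obtain the foreground features.
  • Background path: a background attention component selectively attends to background information.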

The foreground path yields the foreground features Ffg, and the background path yields the background features Fbg.

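The foreground attention map is obtained by applying a sigmoid to the previous decoder layer's feature map, and the background attention map is 1 minus the foreground attention map. Finally, Ffg, Fbg, and the previous decoder layer's features are concatenated to form the output of the current layer.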
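The dual-path fusion described above can be sketched as follows. This is a hypothetical module under stated assumptions, not the repository code: DualPathFusion is an illustrative name, both inputs are assumed to have the same channel count so the sigmoid attention can be applied per channel, and each Conv-BN-ReLU path is a single 3×3 layer. In CTO_net.py the closest counterpart is DM, which also uses 1 - sigmoid(previous decoder features) as background attention but sums its branches instead of concatenating them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(cin, cout):
    return nn.Sequential(
        nn.Conv2d(cin, cout, 3, padding=1, bias=False),
        nn.BatchNorm2d(cout),
        nn.ReLU(inplace=True),
    )


class DualPathFusion(nn.Module):
    """Hypothetical BIM-style block: foreground path + background-attention path."""

    def __init__(self, channels, out_ch):
        super().__init__()
        self.fg_path = conv_bn_relu(2 * channels, out_ch)  # concat -> Conv-BN-ReLU
        self.bg_path = conv_bn_relu(channels, out_ch)      # on background-weighted features

    def forward(self, enc_feat, prev):
        # Bring the previous decoder features to the encoder resolution.
        prev = F.interpolate(prev, size=enc_feat.shape[2:], mode="bilinear", align_corners=False)
        f_fg = self.fg_path(torch.cat([enc_feat, prev], dim=1))
        fg_att = torch.sigmoid(prev)   # foreground attention map
        bg_att = 1.0 - fg_att          # background attention map
        f_bg = self.bg_path(enc_feat * bg_att)
        # Output: [Ffg, Fbg, previous decoder features] along channels.
        return torch.cat([f_fg, f_bg, prev], dim=1)


block = DualPathFusion(channels=64, out_ch=32)
enc = torch.randn(2, 64, 32, 32)
prev = torch.randn(2, 64, 16, 16)
out = block(enc, prev)
print(out.shape)  # torch.Size([2, 128, 32, 32])
```

The background path sees the encoder features only where the previous prediction is not confident foreground, which is what lets it focus on the regions the foreground path tends to miss.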

Loss Function

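CTO is a multi-task model covering interior segmentation and boundary segmentation, and it defines an overall loss function that optimizes both tasks jointly.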

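The overall loss consists of the main interior segmentation loss L_seg and the boundary loss L_bnd. For the boundary loss, only the prediction from BEM is considered; BEM takes the encoder's high-level and low-level feature maps as input.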

Interior Segmentation Loss

L_seg is a weighted sum of the cross-entropy loss L_CE and the mean Intersection-over-Union (mIoU) loss L_mIoU:

[Image 3: equation defining L_seg]
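The paper's exact weighting coefficients are not given in this text, so the sketch below combines cross-entropy with a soft (differentiable) mIoU loss using hypothetical weights w_ce and w_iou, both defaulting to 1.0. The soft IoU replaces hard masks with softmax probabilities, a common way to make an IoU loss differentiable; whether CTO uses exactly this form is an assumption.

```python
import torch
import torch.nn.functional as F


def soft_miou_loss(logits, target, eps=1e-6):
    """1 - mean soft IoU over classes. logits: (B, K, H, W); target: (B, H, W) int64."""
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    inter = (probs * onehot).sum(dim=(0, 2, 3))                   # per-class intersection
    union = (probs + onehot - probs * onehot).sum(dim=(0, 2, 3))  # per-class union
    return 1.0 - ((inter + eps) / (union + eps)).mean()


def seg_loss(logits, target, w_ce=1.0, w_iou=1.0):
    """Hypothetical L_seg = w_ce * L_CE + w_iou * L_mIoU (weights are placeholders)."""
    return w_ce * F.cross_entropy(logits, target) + w_iou * soft_miou_loss(logits, target)


target = torch.randint(0, 3, (2, 8, 8))
logits = torch.randn(2, 3, 8, 8, requires_grad=True)
loss = seg_loss(logits, target)
loss.backward()  # both terms are differentiable w.r.t. the logits
```

Cross-entropy gives dense per-pixel gradients, while the IoU term scores each class region as a whole, so small structures are not dominated by large ones.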

Boundary Loss

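Because boundary detection suffers from class imbalance between foreground (boundary) and background pixels, the boundary loss L_bnd uses the Dice loss: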

[Image 4: Dice loss equation for L_bnd]
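A common soft Dice formulation for the boundary branch is sketched below. The smoothing term eps=1.0 is an assumption, and pred is expected to hold probabilities, which matches CTO.forward returning oe as an upsampled sigmoid map (edge_att) rather than logits.

```python
import torch


def dice_loss(pred, target, eps=1.0):
    """Soft Dice loss. pred: probabilities in [0, 1], (B, 1, H, W); target: binary map."""
    pred = pred.flatten(1)
    target = target.flatten(1)
    inter = (pred * target).sum(dim=1)
    denom = pred.sum(dim=1) + target.sum(dim=1)
    return (1.0 - (2.0 * inter + eps) / (denom + eps)).mean()


# A one-pixel-wide vertical boundary in a 64x64 map (64 of 4096 pixels).
boundary = torch.zeros(1, 1, 64, 64)
boundary[..., 32] = 1.0
all_background = torch.zeros_like(boundary)
print(dice_loss(all_background, boundary).item())
```

The all-background prediction still reaches 98.4% pixel accuracy (4032 of 4096 pixels), yet its Dice loss is 64/65 ≈ 0.985, so the loss is not swamped by the easy background pixels.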

Experiments

[Images 5 to 9: experimental results reported in the paper]

Key code

CTO_net.py

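# https://github.com/xiaofang007/CTO/blob/main/CTOTrainer/network/CTO_net.py

import math
from math import log

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# res2net50_v1b_26w_4s, BaseNetwork and FeedForward2D are not defined in this
# excerpt; they come from elsewhere in the CTO repository.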

class ConvBNR(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(ConvBNR, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=dilation, dilation=dilation, bias=bias),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class Conv1x1(nn.Module):
    def __init__(self, inplanes, planes):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(inplanes, planes, 1)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        return x


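# NOTE: EAM is redefined further down in this file; that later definition,
# with forward(x4, x1), shadows this one and is the one CTO actually uses.
class EAM(nn.Module):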
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(512, 256)
        self.block = nn.Sequential(
            ConvBNR(320 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x1, x11, p2):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)
        x11 = self.reduce1(x11)
        p2 = self.reduce4(p2)
        p2 = F.interpolate(p2, size, mode='bilinear', align_corners=False)
        out = torch.cat((x1, x11), dim=1)
        out = torch.cat((out, p2), dim=1)
        out = self.block(out)

        return out



class EFM(nn.Module):
    def __init__(self, channel):
        super(EFM, self).__init__()
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att + c
        x = self.conv2d(x)
        wei = self.avg_pool(x)
        wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        wei = self.sigmoid(wei)
        x = x * wei

        return x

class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class DM(nn.Module):
    def __init__(self):
        super(DM, self).__init__()
        self.predict3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=1)
        )
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, xr, dualattention):
        crop_3 = F.interpolate(dualattention, xr.size()[2:], mode='bilinear', align_corners=False)
        re3_feat = self.predict3(torch.cat([xr, crop_3], dim=1))
        x = -1*(torch.sigmoid(crop_3)) + 1
        x = x.expand(-1, 64, -1, -1).mul(xr)
        x = F.relu(self.ra2_conv2(x))
        x = F.relu(self.ra2_conv3(x))
        ra3_feat = self.ra2_conv4(x)
        x = ra3_feat + crop_3 + re3_feat


        return x


class _DAHead(nn.Module):
    def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DAHead, self).__init__()
        self.aux = aux
        inter_channels = in_channels // 4
        self.conv_p1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.pam = _PositionAttentionModule(inter_channels, **kwargs)
        self.cam = _ChannelAttentionModule(**kwargs)
        self.conv_p2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.out = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, nclass, 1)
        )
        if aux:
            self.conv_p3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )
            self.conv_c3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )

    def forward(self, x):
        feat_p = self.conv_p1(x)
        feat_p = self.pam(feat_p)
        feat_p = self.conv_p2(feat_p)

        feat_c = self.conv_c1(x)
        feat_c = self.cam(feat_c)
        feat_c = self.conv_c2(feat_c)

        feat_fusion = feat_p + feat_c

        outputs = []
        fusion_out = self.out(feat_fusion)
        outputs.append(fusion_out)
        if self.aux:
            p_out = self.conv_p3(feat_p)
            c_out = self.conv_c3(feat_c)
            outputs.append(p_out)
            outputs.append(c_out)

        return tuple(outputs)

def run_sobel(conv_x, conv_y, input):
    g_x = conv_x(input)
    g_y = conv_y(input)
    g = torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
    return torch.sigmoid(g) * input

def get_sobel(in_chan, out_chan):
    '''
    filter_x = np.array([
        [3, 0, -3],
        [10, 0, -10],
        [3, 0, -3],
    ]).astype(np.float32)
    filter_y = np.array([
        [3, 10, 3],
        [0, 0, 0],
        [-3, -10, -3],
    ]).astype(np.float32)
    '''
    filter_x = np.array([
        [1, 0, -1],
        [2, 0, -2],
        [1, 0, -1],
    ]).astype(np.float32)
    filter_y = np.array([
        [1, 2, 1],
        [0, 0, 0],
        [-1, -2, -1],
    ]).astype(np.float32)
    filter_x = filter_x.reshape((1, 1, 3, 3))
    filter_x = np.repeat(filter_x, in_chan, axis=1)
    filter_x = np.repeat(filter_x, out_chan, axis=0)

    filter_y = filter_y.reshape((1, 1, 3, 3))
    filter_y = np.repeat(filter_y, in_chan, axis=1)
    filter_y = np.repeat(filter_y, out_chan, axis=0)

    filter_x = torch.from_numpy(filter_x)
    filter_y = torch.from_numpy(filter_y)
    filter_x = nn.Parameter(filter_x, requires_grad=False)
    filter_y = nn.Parameter(filter_y, requires_grad=False)
    conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_x.weight = filter_x
    conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_y.weight = filter_y
    sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
    sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
    return sobel_x, sobel_y

class GlobalFilter(nn.Module):
    def __init__(self, dim=32, h=64, w=33, fp32fft=True):
        super().__init__()
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
        self.w = w
        self.h = h
        self.fp32fft = fp32fft

    def forward(self, x):
        _, _, height, width = x.size()  # original code reused `b` here, shadowing the batch size
        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)

        if self.fp32fft:
            dtype = x.dtype
            x = x.to(torch.float32)

        # 2D real FFT over the spatial dims -> (B, H, W // 2 + 1, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)  # learnable global filter
        x = x * weight
        x = torch.fft.irfft2(x, s=(height, width), dim=(1, 2), norm="ortho")

        if self.fp32fft:
            x = x.to(dtype)

        x = x.permute(0, 3, 1, 2).contiguous()

        return x

class ERB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ERB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, relu=True):
        x = self.conv1(x)
        res = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        if relu:
            return self.relu(x + res)
        else:
            return x+res

class _PositionAttentionModule(nn.Module):
    """ Position attention module"""

    def __init__(self, in_channels, **kwargs):
        super(_PositionAttentionModule, self).__init__()
        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        feat_c = self.conv_c(x).view(batch_size, -1, height * width)
        attention_s = self.softmax(torch.bmm(feat_b, feat_c))
        feat_d = self.conv_d(x).view(batch_size, -1, height * width)
        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
        out = self.alpha * feat_e + x

        return out


class _ChannelAttentionModule(nn.Module):
    """Channel attention module"""

    def __init__(self, **kwargs):
        super(_ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_a = x.view(batch_size, -1, height * width)
        feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
        attention = torch.bmm(feat_a, feat_a_transpose)
        attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
        attention = self.softmax(attention_new)

        feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
        out = self.beta * feat_e + x

        return out
        
class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(2048, 256)
        self.block = nn.Sequential(
            ConvBNR(256 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x4, x1):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)
        x4 = self.reduce4(x4)
        x4 = F.interpolate(x4, size, mode='bilinear', align_corners=False)
        out = torch.cat((x4, x1), dim=1)
        out = self.block(out)

        return out

def attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
        query.size(-1)
    )
    p_attn = F.softmax(scores, dim=-1)
    p_val = torch.matmul(p_attn, value)
    return p_val, p_attn

class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.BatchNorm2d(d_model),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        b, c, h, w = x.size()  # e.g. (B, 256, 64, 64)
        d_k = c // len(self.patchsize)  # channels per patch-size group
        output = []
        _query = self.query_embedding(x)  # (B, c, h, w)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        attentions = []
        # Each channel chunk attends over patches of one size.
        for (width, height), query, key, value in zip(
            self.patchsize,
            torch.chunk(_query, len(self.patchsize), dim=1),
            torch.chunk(_key, len(self.patchsize), dim=1),
            torch.chunk(_value, len(self.patchsize), dim=1),
        ):
            out_w, out_h = w // width, h // height  # number of patches along W and H
            ## 1) embedding and reshape
            query = query.view(b, d_k, out_h, height, out_w, width)
            query = (
                query.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            key = key.view(b, d_k, out_h, height, out_w, width)
            key = (
                key.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            value = value.view(b, d_k, out_h, height, out_w, width)
            value = (
                value.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            y, _ = attention(query, key, value)

            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, out_h, out_w, d_k, height, width)
            y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
            attentions.append(y)
            output.append(y)

        output = torch.cat(output, 1)
        self_attention = self.output_linear(output)

        return self_attention



class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, in_channel=256):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
        self.feed_forward = FeedForward2D(
            in_channel=in_channel, out_channel=in_channel
        )

    def forward(self, rgb):
        self_attention = self.attention(rgb)
        output = rgb + self_attention
        output = output + self.feed_forward(output)
        return output

class PatchTrans(BaseNetwork):
    def __init__(self, in_channel, in_size):  # e.g. in_channel=256, in_size=64
        super(PatchTrans, self).__init__()
        self.in_size = in_size

        # (patch_w, patch_h); on a 64x64 input these give 2x2, 4x4, 8x8 and 16x16 patch grids
        patchsize = [
              (32, 32),
              (16, 16),
              (8, 8),
              (4, 4),
        ]

        self.t = TransformerBlock(patchsize, in_channel=in_channel)

    def forward(self, enc_feat):
        output = self.t(enc_feat)
        return output

class multi(nn.Module):
    def __init__(self, channel):
        super(multi, self).__init__()  # was super(EFM, self), which raises TypeError for `multi`
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att 
        #x = self.conv2d(x)
        #wei = self.avg_pool(x)
        #wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        #wei = self.sigmoid(wei)
        #x = x * wei

        return x

class CTO(nn.Module):
    def __init__(self,seg_classes):
        super(CTO, self).__init__()
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        # if self.training:
        # self.initialize_weights()
        self.fft = GlobalFilter(dim = 3 , h=256, w=129, fp32fft= True)
        
        self.multi_trans = PatchTrans(in_channel=256,in_size=64)
        
        
        
        self.num_class = seg_classes
        self.eam = EAM()
        self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
        self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
        self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
        self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
        
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        self.upsample_3 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        
        self.erb_db_1 = ERB(256, self.num_class)
        self.erb_db_2 = ERB(512, self.num_class)
        self.erb_db_3 = ERB(1024, self.num_class)
        self.erb_db_4 = ERB(2048, self.num_class)
        
        self.head = _DAHead(2048+256, 2048, aux=False)

        

        self.reduce1 = Conv1x1(256, 64)
        self.reduce2 = Conv1x1(512, 64)
        self.reduce3 = Conv1x1(1024, 64)
        self.reduce4 = Conv1x1(2048, 64)
        self.reduce5 = Conv1x1(2048, 1)

        self.dm1 = DM()
        self.dm2 = DM()
        self.dm3 = DM()
        self.dm4 = DM()

        self.predictor1 = nn.Conv2d(64, self.num_class, 1)
        self.predictor2 = nn.Conv2d(64, self.num_class, 1)
        self.predictor3 = nn.Conv2d(64, self.num_class, 1)
        self.predictor4 = nn.Conv2d(64, self.num_class, 1)

    # def initialize_weights(self):
    # model_state = torch.load('./models/resnet50-19c8e357.pth')
    # self.resnet.load_state_dict(model_state, strict=False)

    def forward(self, x):
        fft_fea = self.fft(x)  # (B, 3, 256, 256); computed but not used later in this forward pass
        x1, x2, x3 ,x4= self.resnet(x)#[16, 256, 64, 64]  [16, 512, 32, 32]   [16, 1024, 16, 16]   [16, 2048, 8, 8]
        
        trans = self.multi_trans(x1)#16,256,64,64
        
        s1 = run_sobel(self.sobel_x1, self.sobel_y1, x1)
        s4 = run_sobel(self.sobel_x4, self.sobel_y4, x4)
       
        edge = self.eam(s4, s1)
        edge_att = torch.sigmoid(edge)#[16, 1, 64, 64]
        
        trans = F.interpolate(trans,x4.size()[2:], mode='bilinear', align_corners=False)#256,8,8
        dual_attention = self.head(torch.cat([trans, x4], dim=1))[0]  #2048,8,8
        
        x1a = x1*edge_att
        edge_att2 = F.interpolate(edge_att, x2.size()[2:], mode='bilinear', align_corners=False)
        x2a = x2*edge_att2
        edge_att3 = F.interpolate(edge_att, x3.size()[2:], mode='bilinear', align_corners=False)
        x3a = x3*edge_att3
        
        #x1a = self.efm1(x1, edge_att)
        #x2a = self.efm2(x2, edge_att)
       # x3a = self.efm3(x3, edge_att)
       # x4a = self.efm4(x4, edge_att)
        
        x1r = self.reduce1(x1a)  # (B, 64, 64, 64)
        x2r = self.reduce2(x2a)  # (B, 64, 32, 32)
        x3r = self.reduce3(x3a)  # (B, 64, 16, 16)

        dual_attention = self.reduce4(dual_attention)  # (B, 64, 8, 8)

        c3 = self.dm3(x3r, dual_attention)  # (B, 64, 16, 16)
        c2 = self.dm2(x2r, c3)  # (B, 64, 32, 32)
        c1 = self.dm1(x1r, c2)  # (B, 64, 64, 64)
        

        o3 = self.predictor3(c3)
        o3 = F.interpolate(o3, scale_factor=16, mode='bilinear', align_corners=False)
        o2 = self.predictor2(c2)
        o2 = F.interpolate(o2, scale_factor=8, mode='bilinear', align_corners=False) 
        o1 = self.predictor1(c1)
        o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
        oe = F.interpolate(edge_att, scale_factor=4, mode='bilinear', align_corners=False)

        return  o3, o2, o1, oe
