[Medical Image Series] U-Net v2: Rethinking the Skip Connections of U-Net for Medical Image Segmentation

Paper: https://arxiv.org/pdf/2311.17791.pdf
Code: https://github.com/yaoppeng/U-Net_v2/blob/master/unet_v2/UNet_v2.py

    def forward(self, x):
        seg_outs = []
        # Hierarchical encoder features, from highest (f1) to lowest (f4) resolution
        f1, f2, f3, f4 = self.encoder(x)

        # Enhance each encoder stage with channel attention followed by spatial
        # attention, then map every stage to a common channel width with its Translayer
        f1 = self.ca_1(f1) * f1
        f1 = self.sa_1(f1) * f1
        f1 = self.Translayer_1(f1)

        f2 = self.ca_2(f2) * f2
        f2 = self.sa_2(f2) * f2
        f2 = self.Translayer_2(f2)

        f3 = self.ca_3(f3) * f3
        f3 = self.sa_3(f3) * f3
        f3 = self.Translayer_3(f3)

        f4 = self.ca_4(f4) * f4
        f4 = self.sa_4(f4) * f4
        f4 = self.Translayer_4(f4)

        # SDI: each level fuses ALL four enhanced stages, resampled to the
        # resolution of its own anchor map (f4, f3, f2 and f1 respectively)
        f41 = self.sdi_4([f1, f2, f3, f4], f4)
        f31 = self.sdi_3([f1, f2, f3, f4], f3)
        f21 = self.sdi_2([f1, f2, f3, f4], f2)
        f11 = self.sdi_1([f1, f2, f3, f4], f1)
        # ... the decoder and segmentation heads that fill seg_outs follow in the full file

import torch
import torch.nn as nn
import torch.nn.functional as F


class SDI(nn.Module):
    """Semantics and Detail Infusion: fuses all encoder stages into one target level."""

    def __init__(self, channel):
        super().__init__()

        # One 3x3 smoothing conv per incoming encoder stage (4 stages)
        self.convs = nn.ModuleList(
            [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)])

    def forward(self, xs, anchor):
        # Start from ones so the fusion is a pure Hadamard (element-wise) product
        ans = torch.ones_like(anchor)
        target_size = anchor.shape[-1]

        for i, x in enumerate(xs):
            # Resample every stage to the anchor's spatial size: larger maps are
            # average-pooled down, smaller maps are bilinearly upsampled
            if x.shape[-1] > target_size:
                x = F.adaptive_avg_pool2d(x, (target_size, target_size))
            elif x.shape[-1] < target_size:
                x = F.interpolate(x, size=(target_size, target_size),
                                  mode='bilinear', align_corners=True)

            # Accumulate by element-wise multiplication instead of concatenation
            ans = ans * self.convs[i](x)

        return ans
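
A quick shape sanity check of the resampling logic above, as a minimal sketch; the batch size, the channel width of 32, and the feature-map sizes are arbitrary assumptions here, not values from the paper:

import torch

# Four fake encoder stages already projected to the same channel width,
# at decreasing spatial resolutions, standing in for f1..f4 after the Translayers
f1 = torch.randn(1, 32, 88, 88)
f2 = torch.randn(1, 32, 44, 44)
f3 = torch.randn(1, 32, 22, 22)
f4 = torch.randn(1, 32, 11, 11)

sdi_3 = SDI(channel=32)

# Fuse all four stages at f3's resolution: f1/f2 are pooled down, f4 is upsampled
f31 = sdi_3([f1, f2, f3, f4], f3)
print(f31.shape)  # torch.Size([1, 32, 22, 22]), matching the anchor f3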

In the original U-Net, each upsampling step reuses the features of only a single encoder stage, passed over by concatenation. In U-Net v2, every decoder level instead reuses the features of all encoder stages, fused by a Hadamard (element-wise) product inside the SDI module. Before this fusion, each encoder stage is enhanced by channel attention and spatial attention applied in series (a sketch of that attention pair is given below).
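
The forward() excerpt calls self.ca_i and self.sa_i without showing their definitions. Below is a minimal CBAM-style sketch of what such a channel-then-spatial attention pair could look like; the specific layer choices (reduction ratio, 7x7 conv) are illustrative assumptions, not the repository's actual code:

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Squeeze the spatial dims into one gate per channel (sketch, not the repo code)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.mlp(x)  # (B, C, 1, 1) gate, broadcast when multiplied with x


class SpatialAttention(nn.Module):
    """Squeeze the channels into one gate per spatial location (sketch, not the repo code)."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_map = x.mean(dim=1, keepdim=True)        # (B, 1, H, W)
        max_map = x.max(dim=1, keepdim=True).values  # (B, 1, H, W)
        return self.sigmoid(self.conv(torch.cat([avg_map, max_map], dim=1)))


# The enhancement pattern from forward(): gate by channel, then gate by location
f = torch.randn(1, 64, 88, 88)
ca, sa = ChannelAttention(64), SpatialAttention()
f = ca(f) * f
f = sa(f) * f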
