《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》

论文解读《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awareness and Local Context Constraints》

《半监督医学图像使用跨模型的分割伪监督与形状意识和局部语义约束》

论文地址:论文地址
论文出处:MICCAI2022
论文代码:代码链接

一、摘要:
(1) 在半监督医学图像分割中,可用于训练的标记数据数量有限。为了克服这些挑战,一个基于跨模型伪监督的新框架,该框架利用形状生成预测和局部语义约束。

(2) 框架由两个并行网络组成,一个是形状感知网络( a shape-aware network ),另一个是形状不可知网络(a shape-agnostic network)
形状感知网络通过添加其他网络的预测作为输入,隐式(implicitly)地捕获目标区域的形状信息(information on the shape of target regions)。
另一方面,形状不可知网络利用MonteCarlo dropout不确定性估计为另一个网络生成可靠的伪标签。

MonteCarlo dropout不确定性 :对于一个模型的输出结果,我们想得到这个结果的方差来计算模型不确定性(认知不确定性)。而模型的参数是固定的,一个单独输出值是得不到方差的。如果说——我们能够用同一个模型,对同一个样本进行T次预测,而且这T次的预测值各不相同,就能够计算方差。

问题是同一个模型同一个样本,怎么得到不同的输出呢?我们可以让学到的模型参数不是确定的值,而是服从一个分布,那么模型参数就可以从这个分布中采样得到,每一次采样,得到的模型参数都是不同的,这样模型产生的结果也是不同的,我们的目的就达到了。

但是如何让模型的参数不是确定的而是服从一个分布呢? 现成的dropout就是,使用dropout来训练模型时,模型的参数可以看成是服从一个伯努利分布(比如dropout radio =0.5,一种说法是:这层神经元中有一半会被dropout,换种说法就是——这层的每个神经元都有0.5的概率被dropout,这就是伯努利分布)。但是我们估计模型不确定度肯定是在训练好的模型上,也就是测试模型时估计的。所以我们只需要在预测的时候,仍然将dropout打开,预测 T 次,取预测的平均值就是最终的预测值。并且通过平均值就可以得到方差,这样就得到深度学习的不确定度了。这种方法也被称为MC Dropout贝叶斯神经网络。

但其实,MC dropout 用起来就简单了,不需要修改现有的神经网络模型,只需要神经网络模型中带 dropout 层,无论是标准的 dropout 还是其变种,如 drop-connect,都是可以的。

在训练的时候,MC dropout 表现形式和 dropout 没有什么区别,按照正常模型训练方式训练即可。

在测试的时候,在前向传播过程,神经网络的 dropout 是不能关闭的。这就是和平常使用的唯一的区别。

MC dropout 的 MC 体现在我们需要对同一个输入进行多次前向传播过程,这样在 dropout 的加持下可以得到“不同网络结构”的输出,将这些输出进行平均和统计方差,即可得到模型的预测结果及 uncertainty。而且,这个过程是可以并行的,所以在时间上可以等于进行一次前向传播。

(3) 该框架还包括一个新的损失函数(The local context loss),使网络能够学习分割的局部上下文。

二、引言
(1) 提出了一种新的分割框架,以提高精度,并在半监督场景中生成合理的结果。与上述方法不同,所提框架不施加显式的形状约束,而是利用输入层的形状相关信息来隐式捕获形状约束。采用交叉教学策略的两个平行网络,每个网络在未标记数据上为另一个生成伪标签。

(2) 第一个网络接收原始图像作为输入来执行标准的分割操作。另一方面,第二个网络接收相同的输入以及由第一个网络产生的相应概率图,以隐式捕获形状信息。形状意识网络纠正预测边界,为另一个生成更准确的伪标签。为减轻噪声伪标签的影响,用一种 不确定性估计策略( an uncertainty estimation strategy) 增强了形状无关网络,可以过滤不可靠区域。

(3) 此外,还设计了一种新的损失函数,称为局部上下文损失,在整个图像的局部区域进行准确的预测。这与关注【1】全局重叠的标准Dice损失[16]和【2】流行的交叉熵损失形成对比,前者专注于全局重叠,因此可能导致图像某些区域的形状表示较差,后者在不考虑其局部上下文的情况下纠正单个像素的预测。与 [4](Semi-supervised semantic segmentation with cross pseudo supervision)相比,该方法增加了两种创新策略,使分割管道更鲁棒并提高性能:形状感知以增强指导,局部上下文损失以加强局部区域的准确性。

(4) 虽然[9](Iterative instance segmentation)的实例分割方法也利用了形状先验,但没有在迭代过程中使用这种先验,而是通过形状无关网络和形状意识网络的交叉教学。

(5) 综上所述,本文工作在医学图像半监督分割方面做出了3个重要贡献:
1、针对该任务设计了一种新的跨模型伪监督框架,结合形状意识和局部上下文约束,在有限的标记数据下产生准确的结果。
2、所提出的半监督分割框架是第一个在半监督网络的输入层纳入形状信息的框架,而不需要复杂的形状约束。
3、提出了一种新的损失函数 ,帮助网络学习分割的局部上下文,增强了整个图像的形状表示。

三、方法
(1) 在集合在这里插入图片描述
中有N个标记示例;Yi
是对应的标准标签。H × W为图像大小,C为标准类别的个数。

在集合在这里插入图片描述中有M个未标记示例。
Xi
代表未做处理的数据

(2) 完整的训练集D = S∪U,目标是学习一个分割模型F。

3.1 具有形状意识和局部环境约束的跨模型伪监督
《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第1张图片
图1 具有形状意识和局部上下文约束的跨模型伪监督框架的流程图。Poi (l)和Poi (u)分别表示来自网络Fo的监督和无监督概率图

在S数据集中,一个标签数据Xi都在监督网络中训练,并且使用标准的分割损失函数和标准的标签Yi。

在U数据集中,无标签数据Xi在网络互相提供伪标签用来引导其他参数更新。

在这里插入图片描述

在这里插入图片描述

最后,整个半监督框架的目标损失如下:
在这里插入图片描述
为了限制资源消耗并与其他半监督方法进行公平的比较,在推理时,只使用形状无关网络Fo的输出来获得最终的分割预测。

3.2 损失函数

Shape-aware Loss考虑了形状。通常,所有损失函数都在像素级起作用,Shape-aware Loss会计算平均点到曲线的欧几里得距离,即预测分割到ground truth的曲线周围点之间的欧式距离,并将其用作交叉熵损失函数的系数

(1)监督的损失
在这里插入图片描述

The local context loss 上下文损失
在这里插入图片描述

这里K × L是通过分割图像得到的局部区域尺寸。局部预测P(k,l) i和标准值掩模Y (k,l) i的大小为(h = h / k, w = w / l),这些局部特征中的坐标(r, s)对应于Pi和Yi中的位置[h× (K−1)+ r, w×(L−1)+ s]。

###监督损失函数
class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict & target shape do not match'
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes

class Block_DiceLoss(nn.Module):
    def __init__(self, n_classes, block_num):
        super(Block_DiceLoss, self).__init__()
        self.n_classes = n_classes
        self.block_num = block_num
        self.dice_loss = DiceLoss(self.n_classes)
    def forward(self, inputs, target, weight=None, softmax=False):
        shape = inputs.shape
        img_size = shape[-1]
        div_num = self.block_num
        block_size = math.ceil(img_size / self.block_num)
        if target is not None:
            loss = []
            for i in range(div_num):
                for j in range(div_num):
                    block_features = inputs[:, :, i * block_size:(i + 1) * block_size, j * block_size:(j + 1) * block_size]
                    block_labels = target[:, i * block_size:(i + 1) * block_size, j * block_size:(j + 1) * block_size]
                    tmp_loss = self.dice_loss(block_features, block_labels.unsqueeze(1))
                    loss.append(tmp_loss)
            loss = torch.stack(loss).mean()
        return loss
def create_model(ema=False, in_chns=1):
        # Network definition
     model = net_factory(net_type=args.model, in_chns=in_chns,
                    class_num=num_classes)
if ema:
    for param in model.parameters():
        param.detach_()
return model
model1 = create_model()
outputs1  = model1(volume_batch)
outputs_soft1 = torch.softmax(outputs1, dim=1)
volume_batch1 = torch.cat((volume_batch + noise, outputs_soft1), dim=1)
outputs2 = model2(volume_batch1)
outputs_soft2 = torch.softmax(outputs2, dim=1)
pseudo_outputs2 = torch.argmax(outputs_soft2[args.labeled_bs:].detach(), dim=1, keepdim=False)

ce_loss = CrossEntropyLoss()#####交叉熵损失
loss_block_dice1 = block_loss(outputs_soft1[:args.labeled_bs], label_batch[:args.labeled_bs])
loss1 = 0.5 * (ce_loss(outputs1[:args.labeled_bs], label_batch[:][:args.labeled_bs].long()) + loss_block_dice1)

pseudo_supervision1 = torch.mean(ce_loss1(outputs1[args.labeled_bs:], pseudo_outputs2))
model1_loss = loss1 + consistency_weight * pseudo_supervision1

(2)无监督的损失

《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第2张图片

这些伪标签直接用来构建一个损失来指导Fo:
在这里插入图片描述

在这里插入图片描述
相比之下,Fo提供的伪标签准确率较低。为了缓解这个问题,使用了一种基于蒙特卡罗Dropout的不确定性估计策略。
《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第3张图片

unlabeled_volume_batch = volume_batch[args.labeled_bs:]
T = 8
_, _, w, h = unlabeled_volume_batch.shape
volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1)
stride = volume_batch_r.shape[0] // 2
preds = torch.zeros([stride * T, num_classes, w, h]).cuda()
for i in range(T // 2):
    ema_inputs = volume_batch_r + \
                 torch.clamp(torch.randn_like(
                     volume_batch_r) * 0.1, -0.05, 0.05)
    with torch.no_grad():
        preds[2 * stride * i:2 * stride *
                             (i + 1)] = model1(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, num_classes, w, h)
            preds = torch.mean(preds, dim=0)
uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)
threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num,max_iterations)) * np.log(2)
mask = (uncertainty < threshold).float()

outputs2 = model2(volume_batch1)
outputs1  = model1(volume_batch)
outputs_soft1 = torch.softmax(outputs1, dim=1)
pseudo_outputs1 = torch.argmax(outputs_soft1[args.labeled_bs:].detach(), dim=1, keepdim=False)
pseudo_supervision2 = torch.mean(ce_loss1(outputs2[args.labeled_bs:], pseudo_outputs1) * mask.squeeze(1))

loss_block_dice2 = block_loss(outputs_soft2[:args.labeled_bs], label_batch[:args.labeled_bs])
loss2 = 0.5 * (ce_loss(outputs2[:args.labeled_bs], label_batch[:][:args.labeled_bs].long()) + loss_block_dice2)

model2_loss = loss2 + consistency_weight * pseudo_supervision2

四、实验结果分析
4.1
《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第4张图片
《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第5张图片
Average surface distance 平均表面距离

这个指标就是P中所有点的表面距离的平均。这个指标又可称为 Average Symmetric Surface Distance (ASSD) ,它也是医疗图像分割竞赛CHAOS中的一个评估指标。

《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第6张图片

《Semi-supervised Medical Image Segmentation Using Cross-Model Pseudo-Supervision with Shape Awarene》_第7张图片

为验证所提出方法的有效性,将其与全/有限监督(FS/LS)基线和几个最先进的半监督分割方法进行了比较,
包括:(1)熵最小化(EM) [21];(2)深度对抗网络(DAN) [28];(3) Mean Teacher (MT) [19];(4)不确定性感知平均教师(UAMT) [27];(5)插值一致性训练(ICT) [20];(6)不确定性修正金字塔一致性(URPC) [15];(7)交叉伪监督(CPS)所有被测试的方法都使用相同的实验配置,并且在推理阶段不使用任何后处理操作。

图2 提供了在ACDC和PROMISE12数据集上测试方法的可视化比较。所提出方法比其他方法产生了更可靠的预测,并更好地保留了解剖形态。

4.2 消融实验
表1给出了在ACDC数据集上进行的消融实验的结果,显示了向输入层添加形状信息的有效性,以及不同参数配置(LLC(K, L))对我们的局部上下文损失的好处。当向输入层添加形状信息并使用标准Dice损失(shape information to the input
layer and using a standard Dice loss,S+ lice)时,与基线(普通跨模型伪监督)相比,DSC在RV、Myo和LV(右心室、心肌和左心室)上分别提高了3.28%、3.46%和2.56%。另一方面,去除形状增强模型中的不确定性估计(UE),对相同类别的DSC改进分别降低了0.91%、1.00%和1.46%。这些结果表明,将形状先验作为网络的额外输入有助于提高精度,而使用不确定性估计来过滤低置信区域可以进一步提高性能。

与Myo和LV多为圆形相比,RV的形状更加复杂不规则(见图2),因此更有利于局部上下文损失LLC。

五、结论
(1) 提出了一种新的半监督图像分割框架,由两个网络组成,以交叉教学策略相互提供伪标签:
1)一个形状无关的网络( shape-agnostic network) ,只接收原始图像作为输入,并利用基于蒙特卡罗dropout的不确定性估计为另一个网络生成可靠的伪标签;
2) 形状意识网络(shape-aware network) ,将第一个网络的预测作为附加输入,以纳入形状信息。
(2) 本文还提出一种损失,帮助网络学习整个图像局部区域的精确分割。在ACDC和PROMISE12数据集上的实验结果表明,与最近的半监督分割方法相比,该方法具有更好的性能,并展示了该框架的不同组件的好处。在某些情况下,在局部上下文损失中赋予每个区域同等的重要性,可能不是学习分割复杂结构的最佳方法。

你可能感兴趣的:(计算机视觉,深度学习,人工智能)