SegAN:Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation-笔记

这次介绍将对抗网络第一次应用到医学图像分割的文章。
代码
作者提出用对抗网络进行分割,损失函数使用多尺度L1损失函数。该论文的判别器和生成器分别为critic network (C)和segmentor network (S)。
首先介绍一下SegAN与原始的GAN的区别。

1,损失函数的区别。经典的GAN的generator和discriminator的损失是分开的。本论文为segmentor和critic networks提出一个多尺度损失函数。Critic考虑分割结果和真实label多尺度的特征差异来最大化多尺度L1函数。
2,我们的segmentor S是全卷积网络,只拿通过critic的梯度进行训练,最小化和critic一样的损失函数。
3,SegAN是端对端的,不需要多分辨率的网络输入。

模型

SegAN:Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation-笔记_第1张图片
主要有俩部分:segmentor network S,critic network C。

segmentor network S:类似unet的u型网络,
critic network C:有两个输入,一个是原图和分割网络预测结果mask之后的图,另一个是原图和label进行mask之后的原图。(mask表示相乘)
1.Segmentor,使用maxpooling,4X4@2;使用upsampling放大一倍,再加一个3X3@1的卷积层。
2. Critic,和收缩层有相似的结构。然后计算每层的L1损失,相当于是对不同分辨率都会计算损失,所以说,使用这种分层特征,这个损失可以捕获不同分辨率像素之间的空间关系(像素级特征,低维即超像素特征,中间即patches特征)。使用leakyReLU,BN。没有使用pooling,使用卷积2X2@2代替pooling。

S和C通过对抗的方式轮流训练。

损失函数

传统损失:
SegAN:Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation-笔记_第2张图片
P_data表示原始数据的分布,P_Z是一个随机分布,用来产生随机输入。G表示生成器,D表示判别器,G的目的是通过P_Z产生真实分布P_data。让D分辨不出是产生的还是原图。D要尽力分辨真假。所以说,D要最大化目标函数,G要最小化目标函数。
本文提出的损失函数:
虽然和传统的GAN有相同的对抗训练过程,但是二者目的不一样。传统的GAN的目的是找到数据分布之间的映射,但是本文的目的是输入图像和分割mask的映射。本文损失函数如下:
SegAN:Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation-笔记_第3张图片
N表示训练图像个数。l_mae表示Mean Absolute Error (MAE) 或者L1距离,xn ◦ S(xn )和xn ◦ yn分别表示原图和预测结果,label进行mask操作。f_C (x),表示C网络对输入的分层特征的提取。具体的l_mae如下:
这里写图片描述
L表示C网络总共的层的个数。这里写图片描述表示从x提取的第i层的特征。

Training SegAN

首先,我们固定S,训练C一个step,然后使用通过C的梯度训练一个step S。这两个梯度都是从同一个损失函数计算。
Batch size为64,学习率为0.00002。作者通过网格搜索,确定两个网络中pooling的最优个数,S网络为4个,C网络为3个。

作者证明了该论文的损失是有界的,而且最后会收敛。

作者对原图进行裁剪,但是保证大脑在裁剪后的图中。使用的评价标准Dice,Precision和Sensitivity。实际上网络的输入是160X160,对裁剪后的180X180,进行随机裁剪得到,算是数据增强。输入使用T1c,T2, FLAIR三种模态。

作者尝试了L2,但是效果不好。C网络使用pooling,发现效果没有提升。探索C网络的输入,使用label和S网络的预测结果,作为C网络的输入,但是效果不好。

缺点:1,对小的区域分割结果不好。作者猜测,使用pixel-level loss的损失可能会好一点。分割不同的区域使用不同的结构。2,如果类别多的话,计算消耗较大。

S1-1C:对于每一类,都设计一个S和C,这样,几类就有几个S-C。这种方法就会导致同一个像素会属于多个类别。
S3-1C:只有一个S和C网络。S输出3个通道,分别是3类的概率。这三个输出生成的三个mask图生成3个通道,作为C的输入。
S3-3C:一个S网络,3个分开的C分别处理每个类别。
S3-3C single-scale loss models:这个是对损失做文章。又分为两类:S3-3C-s0,只是用C的输入层计算损失;S3-3C-s3只是用C的输出层计算损失。

结果如图2。
SegAN:Adversarial Network with Multi-scale L1 Loss for Medical Image Segmentation-笔记_第4张图片
最终选择了S3-3C。

你可能感兴趣的:(论文笔记,论文笔记)