本文是论文《Unsupervised Bidirectional Cross-Modality Adaptation via Deeply Synergistic Image and Feature Alignment for Medical Image Segmentation》的阅读笔记。
文章提出了一个名为 SIFA(Synergistic Image and Feature Alignment)的无监督域适应框架。SIFA 的代码见 github。SIFA 从图像和特征两个角度引入了对齐的协同融合。
域适应就是将从源域学习到的知识迁移到目标域中,在此之前 CycleGAN 在域适应方面取得了很好的效果。
SIFA 的一个关键特点是图像变换和分割任务的共享编码器。通过参数共享,本框架中的图像对齐和特征对齐能够协同工作,减少端到端训练过程中的域偏移(domain shift)。同时,另一个研究方向是特征对齐,目的是在对抗性学习的情况下提取深度神经网络的域不变特征。
由于域偏移,跨域之间的图片通常看起来不同,而图像对齐的目的就是减少源域图像和目标域图像之间的这种差异。即给定一个有标签的来自源域的数据集 { x i s , y i s } i = 1 N \{x_i^s,y_i^s\}_{i=1}^N {xis,yis}i=1N,以及一个无标签的来自目标域的数据集 { x i t } j = 1 M \{x_i^t\}_{j=1}^M {xit}j=1M,使得源域图像 x i s x_i^s xis 尽可能的看起来像 目标域图像 x i t x_i^t xit。转换后的图像不仅要看起来像来自目标域,而且还应该保留源域的结构语义内容。
上图是网络的整体结构示意图,可结合以下描述来加以理解。
使用一个生成器 G t G_t Gt 将源域图像转换成与目标域相似的图像,即 G t ( x s ) = x s → t G_t(x^s)=x^{s\rightarrow t} Gt(xs)=xs→t,并使用一个判别器 D t D_t Dt 来判断生成的图像是真正来自目标域还是生成的。这个 GAN 的目标函数为:
L adv t ( G t , D t ) = E x t ∼ X t [ log D t ( x t ) ] + E x s ∼ X s [ log ( 1 − D t ( G t ( x s ) ) ) ] \begin{aligned} \mathcal{L}_{\text {adv}}^{t}\left(G_{t}, D_{t}\right)=& \mathbb{E}_{x^{t} \sim X^{t}}\left[\log D_{t}\left(x^{t}\right)\right]+\\ & \mathbb{E}_{x^{s} \sim X^{s}}\left[\log \left(1-D_{t}\left(G_{t}\left(x^{s}\right)\right)\right)\right] \end{aligned} Ladvt(Gt,Dt)=Ext∼Xt[logDt(xt)]+Exs∼Xs[log(1−Dt(Gt(xs)))]
为了让转换得到的图像 x s → t x^{s\rightarrow t} xs→t 保留源域的特征,通常使用一个反向的生成器来促进图像的循环一致性。图中的 E 是特征编码器,U 是解码器,E 和 U 加起来就相当于一个生成器 G s G_s Gs,即 G s = E ∘ U G_s=E\circ U Gs=E∘U ,它可以将转换得到的目标域图像 x s → t x^{s\rightarrow t} xs→t 再转换回源域。并通过源域的判别器 D s D_s Ds 进行判别,其对抗损失为 L a d v s \mathcal{L}_{adv}^s Ladvs,和目标域上的 GAN 的训练方式一致。通过源域-目标域-源域( x s → t → s = U ( E ( G t ( x s ) ) ) x^{s \rightarrow t \rightarrow s}=U\left(E\left(G_{t}\left(x^{s}\right)\right)\right) xs→t→s=U(E(Gt(xs))))或目标域-源域-目标域( x t → s → t = G t ( U ( E ( x t ) ) ) x^{t \rightarrow s \rightarrow t}=G_{t}\left(U\left(E\left(x^{t}\right)\right)\right) xt→s→t=Gt(U(E(xt))))的转换就得到了图像的循环一致性损失,即:
L c y c ( G t , E , U ) = E x s ∼ X s ∥ U ( E ( G t ( x s ) ) ) − x s ∥ 1 + E x t ∼ X t ∥ G t ( U ( E ( x t ) ) ) − x t ∥ 1 \begin{aligned} \mathcal{L}_{\mathrm{cyc}}\left(G_{t}, E, U\right)=& \mathbb{E}_{x^{s} \sim X^{s}}\left\|U\left(E\left(G_{t}\left(x^{s}\right)\right)\right)-x^{s}\right\|_{1}+\\ & \mathbb{E}_{x^{t} \sim X^{t}}\left\|G_{t}\left(U\left(E\left(x^{t}\right)\right)\right)-x^{t}\right\|_{1} \end{aligned} Lcyc(Gt,E,U)=Exs∼Xs∥U(E(Gt(xs)))−xs∥1+Ext∼Xt∥∥Gt(U(E(xt)))−xt∥∥1
图中的 C 是一个像素级的分类器,E 和 C 加起来 E ∘ C E\circ C E∘C 就相当于一个目标域的分割网络,它的输入包括 x s → t , y s , x t x^{s\rightarrow t},y^s,x^t xs→t,ys,xt,输出是 x s → t , x t x^{s\rightarrow t},x^t xs→t,xt 的分割标签,分割网络通过最小化一个混合损失(分割损失)来优化:
L s e g ( E , C ) = H ( y s , C ( E ( x s → t ) ) + Dice ( y s , C ( E ( x s → t ) ) ) \mathcal{L}_{s e g}(E, C)=H\left(y^{s}, C\left(E\left(x^{s \rightarrow t}\right)\right)+\operatorname{Dice}\left(y^{s}, C\left(E\left(x^{s \rightarrow t}\right)\right)\right)\right. Lseg(E,C)=H(ys,C(E(xs→t))+Dice(ys,C(E(xs→t)))
其中第一项是交叉熵,第二项是 Dice 损失。
为解决跨域的域偏移问题,文章提出了另外的判别器来从特征对齐的角度来减少生成的目标图像 x s → t x^{s\rightarrow t} xs→t 和真正的目标图像 x t x^t xt 的 domain gap。为了对齐以上两种图像的特征,通常的方法是在特征空间直接使用对抗学习,但是特征空间一般是高维的,很难直接对齐。所以文章使用的方法是在两个低维的空间内使用对抗学习,一个是语义预测空间,另一个是生成图像空间。
使用判别器 D p D_p Dp 来对分割网络生成的分割标签进行判别,如果两者的特征没有对齐的话,就通过反向传播对特征提取器 E 进行优化,从而减小生成的目标域图像 x s → t x^{s\rightarrow t} xs→t 和真正的目标域图像 x t x^t xt 的特征分布之间的差异。该对抗损失为:
L a d v p ( E , C , D p ) = E x s → t ∼ X s → t [ log D p ( C ( E ( x s → t ) ) ) ] + E x t ∼ X t [ log ( 1 − D p ( C ( E ( x t ) ) ) ) ] \begin{aligned} \mathcal{L}_{a d v}^{p}\left(E, C, D_{p}\right)=& \mathbb{E}_{x^{s \rightarrow t} \sim X^{s \rightarrow t}\left[\log D_{p}\left(C\left(E\left(x^{s \rightarrow t}\right)\right)\right)\right]+} \\ & \mathbb{E}_{x^{t} \sim X^{t}\left[\log \left(1-D_{p}\left(C\left(E\left(x^{t}\right)\right)\right)\right)\right]} \end{aligned} Ladvp(E,C,Dp)=Exs→t∼Xs→t[logDp(C(E(xs→t)))]+Ext∼Xt[log(1−Dp(C(E(xt))))]
低级特征可能和高级特征的对齐情况并不一样,所以使用额外的和编码器低层的输出相关的像素级分类器来产生额外的辅助预测,然后通过一个判别器来对这些额外预测进行判别。这增强了低级特征的对齐,如此一来, L s e g \mathcal{L}_{seg} Lseg 和 L a d v \mathcal{L}_{adv} Ladv 的表达式就需要进行调整了,它们分别被拓展为 L s e g i ( E , C i ) \mathcal{L}_{seg}^i(E,C_i) Lsegi(E,Ci) 和 L a d v P i ( E , C i , D p i ) \mathcal{L}_{adv}^{P_i}(E,C_i,D_{p_i}) LadvPi(E,Ci,Dpi),其中 i = 1 , 2 i={1,2} i=1,2, C 1 , C 2 C_1,C_2 C1,C2 表示连接到编码器不同层的两个分类器, D p 1 , D p 2 D_{p_1},D_{p_2} Dp1,Dp2 表示对两个分类器的输出进行判别的判别器。
对于生成器 E ∘ U E\circ U E∘U,为判别器 D s D_s Ds 增加一个辅助任务——判别生成的源域图像来自生成的目标域图像 x s → t x^{s\rightarrow t} xs→t 还是来自真正的目标域图像 x t x^t xt。该辅助任务的对抗损失为:
L adv s ~ ( E , D s ) = E x s → t ∼ X s → t [ log D s ( U ( E ( x s → t ) ) ) ] + E x t ∼ X t [ log ( 1 − D s ( U ( E ( x t ) ) ) ) ] \begin{aligned} \mathcal{L}_{\text {adv }}^{\tilde{s}}\left(E, D_{s}\right)=& \mathbb{E}_{x^{s \rightarrow} t \sim X^{s \rightarrow t}}\left[\log D_{s}\left(U\left(E\left(x^{s \rightarrow t}\right)\right)\right)\right]+\\ & \mathbb{E}_{x^{t} \sim X^{t}}\left[\log \left(1-D_{s}\left(U\left(E\left(x^{t}\right)\right)\right)\right)\right] \end{aligned} Ladv s~(E,Ds)=Exs→t∼Xs→t[logDs(U(E(xs→t)))]+Ext∼Xt[log(1−Ds(U(E(xt))))]
在协同学习框架的一个关键是在图像和特征对齐之间共享编码器 E,编码器 E 会通过损失 L a d v s \mathcal{L}_{adv}^s Ladvs 和 L c y c \mathcal{L}_{cyc} Lcyc,以及判别器 D p i , D s D_{p_i},D_s Dpi,Ds 的反向传播来进行优化。
在训练时各个模块的训练顺序为: G t → D t → E → C i → U → D s → D p i G_t\rightarrow D_t \rightarrow E \rightarrow C_i \rightarrow U \rightarrow D_s \rightarrow D_{p_i} Gt→Dt→E→Ci→U→Ds→Dpi。整个网络的目标函数为:
L = L a d v t ( G t , D t ) + λ a d v s L a d v s ( E , U , D s ) + λ g s L c s c ( G t , E , U ) + λ seg 1 L seg 1 ( E , C 1 ) + λ seg 2 L seg 2 ( E , C 2 ) + λ a d v p 1 L a d v p 1 ( E , C , D p 1 ) + λ adv p 2 L a d v p 2 ( E , C , D p 2 ) + λ a d v s ~ L a b s ~ ( E , D s ) \begin{aligned} \mathcal{L}=& \mathcal{L}_{a d v}^{t}\left(G_{t}, D_{t}\right)+\lambda_{a d v}^{s} \mathcal{L}_{a d v}^{s}\left(E, U, D_{s}\right)+\\ & \lambda_{\mathrm{gs}} \mathcal{L}_{\mathrm{csc}}\left(G_{t}, E, U\right)+\lambda_{\operatorname{seg}}^{1} \mathcal{L}_{\operatorname{seg}}^{1}\left(E, C_{1}\right)+\\ & \lambda_{\operatorname{seg}}^{2} \mathcal{L}_{\operatorname{seg}}^{2}\left(E, C_{2}\right)+\lambda_{a d v}^{p_{1}} \mathcal{L}_{a d v}^{p_{1}}\left(E, C, D_{p_{1}}\right)+\\ & \lambda_{\text {adv}}^{p_{2}} \mathcal{L}_{a d v}^{p_{2}}\left(E, C, D_{p_{2}}\right)+\lambda_{a d v}^{\tilde{s}} \mathcal{L}_{a b}^{\tilde{s}}\left(E, D_{s}\right) \end{aligned} L=Ladvt(Gt,Dt)+λadvsLadvs(E,U,Ds)+λgsLcsc(Gt,E,U)+λseg1Lseg1(E,C1)+λseg2Lseg2(E,C2)+λadvp1Ladvp1(E,C,Dp1)+λadvp2Ladvp2(E,C,Dp2)+λadvs~Labs~(E,Ds)
其中 { λ a d v s , λ c y c , λ s e g 1 , λ s e g 2 , λ a d v p 1 , λ a d v p 2 , λ a d v s ~ } \left\{\lambda_{a d v}^{s}, \lambda_{c y c}, \lambda_{s e g}^{1}, \lambda_{s e g}^{2}, \lambda_{a d v}^{p_{1}}, \lambda_{a d v}^{p_{2}}, \lambda_{a d v}^{\tilde{s}}\right\} {λadvs,λcyc,λseg1,λseg2,λadvp1,λadvp2,λadvs~} 是用于平衡各项的参数,在实验时分别设为 { 0.1 , 10 , 1.0 , 0.1 , 0.1 , 0.01 , 0.1 } \{0.1,10,1.0,0.1,0.1,0.01,0.1\} {0.1,10,1.0,0.1,0.1,0.01,0.1}。