Generation of 3D Brain MRI Using Auto-Encoding Generative Adversarial Networks论文解读

Generation of 3D Brain MRI Using Auto-Encoding Generative Adversarial Networks

    • 摘要
    • 介绍
    • 方法
      • 模型结构
      • 损失函数
      • 训练流程

  本文出自MICCAI2019。

摘要

  随着深度学习在医学图像分析任务中显示出前所未有的成功,缺乏足够的医学数据正成为一个关键问题。近年来,利用生成对抗网络(GAN)解决有限数据问题的尝试在生成具有多样性的真实图像方面取得了成功,但大部分工作都是基于图像到图像的转换,因此需要来自不同领域的大量数据集。在此,我们提出了一个新的模型,通过学习数据分布,可以成功地从随机向量生成三维脑MRI数据。我们的三维GAN模型结合了变分自动编码器(VAE)和GAN的优点,并引入了额外的编码鉴别网络,利用α-GAN来解决图像模糊和模式崩溃问题。我们也使用带梯度惩罚的Wasserstein-GAN(WGAN-GP)来降低训练的不稳定性。为了证明我们的模型的有效性,我们生成了正常大脑MRI的新图像,并且表明我们的模型在定量和定性测量方面都优于基准模型。我们还训练该模型来合成脑部疾病的MRI数据,以证明我们的模型的广泛适用性。我们的结果表明,所提出的模型可以从一组小的训练数据中成功地生成各种类型和模式的三维全脑MRI影像数据。

介绍

  近年来,深度神经网络,特别是卷积神经网络(CNN)在基于可获得大数据的分类和分割等各种计算机视觉任务中表现出了优异的性能。伴随着这些成果,医学图像分析领域,包括疾病诊断和病变检测领域也取得了显著的突破。然而,训练基于CNN的模型需要大量的医学图像数据,这些数据的获取既费时又费钱。传统的几何变换方法(如翻转、旋转)可以用来增强训练数据,但它们的效果高度依赖于原始数据。
  生成性对抗网络(GAN)是一种积极探索的解决数据缺乏问题的方法,它在计算机视觉领域取得了成功,生成了逼真但与原始图像不同的自然图像。使用GAN的医学图像生成有两种主要方法:图像到图像转换和从随机分布生成图像。前者在各种情况下都得到了广泛的探索,在这种方法中,训练相对容易,因为它是在另一个数据集的指导下完成的,并且生成的图像的质量与真实图像的质量相当。然而,它需要大量的训练数据,并且生成的输出依赖于原始数据的属性,如形状。在后者中,该方法通过学习数据本身的分布,可以生成变化性更强的全新图像。然而,由于稳定训练的困难,以往的工作局限于二维切片的生成,因此很少尝试。
  在本文中,我们提出了一个3D-GAN模型,该模型成功地从随机向量生成了3D脑MRI。我们将α-GAN的结构适应于三维图像生成,通过在现有的生成器和鉴别器之上引入附加的自动编码器和编码鉴别器网络来解决模式崩溃和图像模糊。我们还利用带梯度惩罚的Wasserstein-GAN(WGAN-GP)损失函数来防止不稳定训练。据我们所知,这项工作是首次尝试从随机分布中产生全新的三维脑磁共振成像。为了证明模型的通用性,我们用脑肿瘤和脑卒中病变的MRI训练模型,并证明它可以生成多种类型(如正常或病变)和模式(如T2或FLAIR)的真实三维全脑图像。由于该模型需要少量的数据,因此有望广泛应用于医学图像分析任务,如疾病诊断,特别是对罕见疾病的诊断。

方法

模型结构

  3D生成的主要挑战是模式崩溃问题,其中GAN只产生有限的各种图像,随着任务的复杂性突然从2D增加到3D生成,模式变得更加严重。一种自然的替代方法是使用变分自动编码器(VAE),它不受模式崩溃的影响,但输出具有模糊性。为了有效地解决GAN的模式崩溃问题和VAE的图像模糊问题,作者使用了DeepMind于2017年提出来的结构 (“Variational Approaches for Auto-Encoding Generative Adversarial Networks”),将GAN和VAE结合。 提出了一种编码判别器架构C来代替VAE中的变分推理,因为变分推理和GAN的目标函数有所冲突。编码判别器C将VAE的encoder的输出(后验分布)认为是假的,而将随机噪声向量(先验分布)认为是真的,让编码判别器和encoder也构成一个对抗关系。因此,当编码判别器C无法区分二者时,encoder的输出即后验分布与随机噪声向量即先验分布完美匹配。因此,后验概率可以被完全估计。
Generation of 3D Brain MRI Using Auto-Encoding Generative Adversarial Networks论文解读_第1张图片
  作者采取了α-GAN的网络结构来实现三维生成任务,如上图所示。判别器和编码器网络各有5层的3D卷积层,每一层使用了4x4x4卷积核。因为判别器最后一层的输出是一个值,编码器最后一层的输出是一个向量,输出通道尺寸是据此设定的。在每个卷积层后,使用了Batch Normalization和LeakyReLU。在第一层和最后一层,BatchNorm被移除,以保持输入和输出中各个元素之间的独立性。
  在生成器网络中,作者使用了resize-convolution(“Deconvolution and checkerboard artifacts”)来减少参数的数量和伪影。带有3x3x3卷积核的conventional nearest neighbor upscale(卷积近邻上采样)在卷积层之前使用以取代转置卷积层(反卷积层)。为了训练的稳定性,BatchNorm和ReLU在每个卷积层中被应用,除了最后一层,其中BatchNorm被移除,并且使用了Tanh激活函数。编码判别器C由3个全连接层组成,与判别器相似的是,LeakyReLU和BatchNorm层在每个全连接层之间被放置。

损失函数

  (1)对于生成器和判别器使用了WGAN-GP的损失函数——以Wasserstein距离度量两分布的差距,同时还引入了梯度惩罚项(Gradient Penalty term),通过取代梯度修剪提升效果。
  WGAN:DCGAN用经验告诉我们什么是比较稳定的GAN网络结构, 而WGAN告诉我们: 不用精巧的网络设计和训练过程, 也能训练一个稳定的GAN。
在这里插入图片描述

  WGAN 通过剪裁D网络参数的方式, 对D网络进行稳定更新(Facebook采用了一种名叫“Earth-Mover”的距离(也称Wasserstein距离)来度量分布相似度)。
  但是, 有时一味地通过裁剪D网络weight参数的方式保证训练稳定性, 可能导致生成低质量低清晰度的图片。为了解决WGAN有时生成低质量图片的问题, WGAN-GP舍弃裁剪D网络weights参数的方式, 而是采用裁剪D网络梯度的方式(依据输入数据裁剪), 以下是WGAN-GP的判别器D的Value函数和生成器G的Value函数:
在这里插入图片描述
  (2)编码判别器和编码器分别采用了WGAN-GP相同形式的判别器和生成器的损失函数,编码器之所以没有采取别的形式是因为它实际上是编码判别器的生成器版本。
  (3)在生成器损失函数中,添加了关于重建图像和真实图像的L1损失。
  (4)在判别器损失和编码判别器损失中,分别加入了梯度惩罚项。
  最终的损失函数如下:
在这里插入图片描述

训练流程

  将编码器和生成器视作一个网络,将二者的损失函数合成一个。并且按照如下顺序进行优化:(1)编码器-生成器网络(2)判别器网络(3)编码判别器网络,并且由于生成器的优化速度更慢,在程序的一个step中作者对其更新两次。

你可能感兴趣的:(深度学习)