生成对抗网络( Generative Adversarial Network,GAN)在医疗AI领域尤其是医学成像领域具有重要价值。大多数用于解决医学问题的最新机器学习算法都依赖大型临床和生物医学数据集来有效地训练模型。然而,由于医疗数据的专有性质,数据伦理、保护和患者隐私保密等障碍常常出现,这阻碍了从业者获取这些数据集。此外,某些成像方式获取成本高昂,尤其是当我们处理的疾病较为罕见时。视网膜成像就是一个例子,在这一领域,许多罕见的遗传性疾病的数据集非常有限。GANs有助于扩充这些数据集,使我们能够构建更优秀的机器学习系统来诊断和理解疾病。在这篇文章中,我将简单介绍GANs,虽然会涉及一些数学内容。然后,我将训练一个用于MNIST数字的简单GAN,并讨论训练GANs时遇到的一些挑战。更多细节可以参考Goodfellow的原始论文《生成对抗网络》。
GAN 本质上是由两个互相玩游戏的神经网络组成。 第一个网络称为生成器,第二个网络称为鉴别器。 生成器的工作就像造假者一样——创造一些看起来很真实但实际上不是 100 美元钞票的东西。 另一方面,鉴别器的工作是鉴定专家的工作:检查 100 美元钞票并判断它看起来是否是真品。 生成器和判别器也分别称为“生成”模型和“判别”模型。 从概率的角度来看,如果我们假设世界上所有真实的 100 美元钞票 X X X 和 Y Y Y 的分类标签的概率分布 ∈ { “real” , “fake’ ’ } \in \{\text{``real''}, \text{``fake' '}\} ∈{“real”,“fake’ ’} ,生成模型尝试学习 p ( X ∣ Y ) p(X|Y) p(X∣Y) 或 p ( X , Y ) p(X, Y) p(X,Y) ,而判别模型尝试学习 p ( Y ∣ X ) p(Y|X) p(Y∣X)。
这种模型设置被称为“对抗性”的原因如下:生成器G努力创造一种尽可能接近实际数据分布的合成数据分布,即最小化两个分布之间的距离;而鉴别器D则试图区分真实数据和伪造数据,即最大化分布之间的距离。这在形式上被表示为一个极小极大目标:
min G max D V ( D , G ) \min _G \max _D V(D, G) GminDmaxV(D,G)
其中函数 V ( D , G ) V(D, G) V(D,G) 为原始论文中定义的极小极大损失函数:
V ( D , G ) = E [ log ( D ( X ) ) ] + E [ log ( 1 − D ( G ( z ) ) ) ] V(D, G)=E[\log (D(X))]+E[\log (1-D(G(z)))] V(D,G)=E[log(D(X))]+E[log(1−D(G(z)))]
这个损失函数是什么意思呢?由于 D D D 是一个判别模型, D ( X ) D(X) D(X) 实际上是一个真实图像 X X X 被分类为真实的概率。 G ( z ) G(z) G(z) 是从一些随机噪声 z z z 生成的合成图像,而 D ( G ( z ) ) D(G(z)) D(G(z)) 是合成图像被分类为真实的概率。 1 − D ( G ( z ) ) 1 - D(G(z)) 1−D(G(z)) 则是合成图像被分类为假的概率。 E [ ⋅ ] E[\cdot] E[⋅] 表示所有样本的平均值或期望值。这样一来,极小极大的概念就清晰了!判别器希望最大化真实被分类为真实以及假的被分类为假的概率。生成器则希望最小化假的被分类为假的概率,即第二项——第一项 E [ D ( X ) ] E[D(X)] E[D(X)] 不依赖于 G G G,因此可以忽略不计。
现在我们已经了解了 GAN 的基本组件,接下来我们可以研究如何训练这个模型。 作为示例用例,我将生成手写数字。 用于训练模型的数据集将是著名的 MNIST 数据集,其中包含 60,000 张手写数字的灰度图像。
首先,我将设置基本的生成器架构。该模型采用 64 维随机向量 z \pmb{z} z。 然后,它对向量执行一系列转置卷积运算,最终创建 (28, 28) 灰度图像。
class Generator(nn.Module):
'''
Generator Class
Values:
z_dim: the dimension of the noise vector, a scalar
im_chan: the number of channels of the output image, a scalar
(MNIST is black-and-white, so 1 channel is your default)
hidden_dim: the inner dimension, a scalar
'''
def __init__(self, z_dim=64, im_chan=1, hidden_dim=64):
super(Generator, self).__init__()
self.z_dim = z_dim
# Build the neural network
self.gen = nn.Sequential(
self.make_gen_block(z_dim, hidden_dim * 4),
self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
self.make_gen_block(hidden_dim * 2, hidden_dim),
self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
)
def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
'''
Function to return a sequence of operations corresponding to a generator block of DCGAN,
corresponding to a transposed convolution, a batchnorm (except for in the last layer), and an activation.
Parameters:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
'''
# Build the neural block
if not final_layer:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(num_features=output_channels),
nn.ReLU()
)
else: # Final Layer
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
nn.Tanh()
)
def unsqueeze_noise(self, noise):
'''
Function for completing a forward pass of the generator: Given a noise tensor,
returns a copy of that noise with width and height = 1 and channels = z_dim.
Parameters:
noise: a noise tensor with dimensions (n_samples, z_dim)
'''
return noise.view(len(noise), self.z_dim, 1, 1)
def forward(self, noise):
'''
Function for completing a forward pass of the generator: Given a noise tensor,
returns generated images.
Parameters:
noise: a noise tensor with dimensions (n_samples, z_dim)
'''
x = self.unsqueeze_noise(noise)
return self.gen(x)
接下来,我创建鉴别器模型。 该模型获取一张图像(无论是真实的还是假的),并相应地对其进行分类。 该架构就像任何二元分类器一样——一些卷积对图像进行下采样,然后是完全连接的二元输出层。
class Discriminator(nn.Module):
'''
Discriminator Class
Values:
im_chan: the number of channels of the output image, a scalar
(MNIST is black-and-white, so 1 channel is your default)
hidden_dim: the inner dimension, a scalar
'''
def __init__(self, im_chan=1, hidden_dim=16):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
self.make_disc_block(im_chan, hidden_dim),
self.make_disc_block(hidden_dim, hidden_dim * 2),
self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
)
def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
'''
Function to return a sequence of operations corresponding to a discriminator block of DCGAN,
corresponding to a convolution, a batchnorm (except for in the last layer), and an activation.
Parameters:
input_channels: how many channels the input feature representation has
output_channels: how many channels the output feature representation should have
kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
stride: the stride of the convolution
final_layer: a boolean, true if it is the final layer and false otherwise
(affects activation and batchnorm)
'''
# Build the neural block
if not final_layer:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride),
nn.BatchNorm2d(num_features=output_channels),
nn.LeakyReLU(negative_slope=0.2)
)
else: # Final Layer
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size, stride)
)
'''
Function for completing a forward pass of the discriminator: Given an image tensor,
returns a 1-dimension tensor representing fake/real.
Parameters:
image: a flattened image tensor with dimension (im_dim)
'''
def forward(self, image):
disc_pred = self.disc(image)
return disc_pred.view(len(disc_pred), -1)
现在是主要部分:训练。 原始 GAN 论文中提出的基本算法如下:使用极小极大目标函数,首先通过梯度上升对判别器进行 k 次迭代训练(因为它想要最大化!)。 接下来,生成器通过梯度下降(因为它想要最小化!)进行单次迭代的训练。 这样做的原因是,对于生成器来说,要学习如何从头开始创建图像,它首先需要接收一些“信号”,以激励它提高其生成能力。 因此,首先训练鉴别器将使其比生成器稍微更好地辨别真假图像。 然后,在更新生成器时,该信息会被反向传播,从而使其得以改进。 最后,如果您感兴趣,原始论文给出了该算法将收敛到真实数据分布的数学证明。
下面是执行该算法的代码。 为简单起见,我交替更新判别器和生成器(即 k = 1 k=1 k=1)。 另外,请注意,极小极大损失在数学上计算为真实图像和生成图像的二元交叉熵损失的平均值!
# training requirements
device='cuda'
n_epochs = 50
criterion = nn.BCEwithLogitsLoss()
batch_size = 128
# generator and optimizer
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
# prepare discriminator and optimizer
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))
# stores losses per epoch
generator_losses = []
discriminator_losses = []
for epoch in range(n_epochs):
# stores losses averaged over all batches
mean_gen_loss = 0
mean_disc_loss = 0
# Dataloader returns the batches
for real, _ in tqdm(dataloader):
cur_batch_size = len(real)
real = real.to(device)
## Update discriminator ##
disc_opt.zero_grad()
fake_noise = torch.rand(cur_batch_size, z_dim, device=device)
fake = gen(fake_noise)
disc_fake_pred = disc(fake.detach())
disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_real_pred = disc(real)
disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
# Keep track of the average discriminator loss
mean_dis_loss += disc_loss.item() / cur_batch_size
# Update gradients
disc_loss.backward(retain_graph=True)
# Update optimizer
disc_opt.step()
## Update generator ##
gen_opt.zero_grad()
fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
fake_2 = gen(fake_noise_2)
disc_fake_pred = disc(fake_2)
gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss.backward()
gen_opt.step()
# Keep track of the average generator loss
mean_gen_loss += gen_loss.item() / cur_batch_size
generator_losses.append(mean_gen_loss)
discriminator_losses.append(mean_disc_loss)
经过训练后,我得到了以下图像生成结果。虽然有些图像看起来不太清晰,但总的来说效果还是相当不错的!
尽管生成对抗网络(GAN)的整体设置很优雅,但实际的训练过程却并非如此!我执行的手写数字生成任务相对简单,而一般的图像数据集则包含更为复杂的详细特征,难以生成。
另外,关于模型训练,我们现在同时优化两个神经网络,仅观察损失函数是不够的——还必须确保两个网络中没有一个过于强大。直观地想,如果鉴别器变得太好,以至于能识别出任何假图像,那么生成器可能会想,“如果我总是被发现,那我还提升什么。”但如果生成器变得太强大,轻易地欺骗了鉴别器,这种情况对生成器也不利,因为它会认为,“我甚至不用太努力,因为鉴别器很差!”实际上,前者情况更有可能发生,因为生成逼真样本比将其分类为真或假要复杂得多——这意味着鉴别器更有可能赢得极小极大游戏。从计算上讲,这种情况可能导致训练发散和梯度消失等后续问题,这是我们希望避免的。
GANs的另一个常见问题是模式坍塌——即生成器找到了欺骗鉴别器的“作弊码”并不断使用它来赢得游戏。例如,如果我们有一个多模态分布——手写数字的分布对于“1”、“2”、“3”等都有一个模式。生成“1”可能比其他数字容易得多,因此在训练过程中,生成器会一直创造“1”,欺骗鉴别器。我们不希望这种情况发生,因为我们期望GAN生成多样化的数据样本。
最后,像条件生成(即生成特定类别的图像)、可控生成(即创造带有特定细节的图像)等数据生成的其他方面,在基本的GAN中并不可用。在过去五年中,GANs研究领域取得了突破性进展,许多论文讨论了更好的设计和训练GANs的方法以实现这些目标。请查看GAN动物园,这是一个包含迄今为止开发的各种类型GANs的大型仓库。
不管怎样,您现在已经了解了生成对抗网络的基础知识,恭喜!在未来的帖子中,我希望深入探讨GANs的细节,讨论不同的最新模型,以及最重要的是,医疗领域的应用。敬请期待!