34-生成式对抗网络模型综述
作者:张真源
GAN
GAN简介
生成式对抗网络(Generative adversarial networks,GANs)的核心思想源自于零和博弈,包括生成器和判别器两个部分。生成器接收随机变量并生成“假”样本,判别器则用于判断输入的样本是真实的还是合成的。两者通过相互对抗来获得彼此性能的提升。判别器所作的其实就是一个二分类任务,我们可以计算他的损失并进行反向传播求出梯度,从而进行参数更新。
[图片上传失败...(image-6a5bbf-1595332746705)]
GAN的优化目标可以写作:
其中代表了判别器鉴别真实样本的能力,而则代表了生成器欺骗判别器的能力。在实际的训练中,生成器和判别器采取交替训练,即先训练D,然后训练G,不断往复。
WGAN
在上一部分我们给出了GAN的优化目标,这个目标的本质是在最小化生成样本与真实样本之间的JS距离。但是在实验中发现,GAN的训练非常的不稳定,经常会陷入坍缩模式。这是因为,在高维空间中,并不是每个点都可以表示一个样本,而是存在着大量不代表真实信息的无用空间。当两个分布没有重叠时,JS距离不能准确的提供两个分布之间的差异。这样的生成器,很难“捕捉”到低维空间中的真实数据分布。因此,WGAN(Wasserstein GAN)的作者提出了Wasserstein距离(推土机距离)的概念,其公式可以进行如下表示:
这里指的是真实分布和生成分布的联合分布所构成的集合,是从中取得的一个样本。枚举两者之间所有可能的联合分布,计算其中样本间的距离,并取其期望。而Wasserstein距离就是两个分布样本距离期望的下界值。这个简单的改进,使得生成样本在任意位置下都能给生成器带来合适的梯度,从而对参数进行优化。
DCGAN
卷积神经网络近年来取得了耀眼的成绩,展现了其在图像处理领域独特的优势。很自然的会想到,如果将卷积神经网络引入GAN中,是否可以带来效果上的提升呢?DCGAN(Deep Convolutional GANs)在GAN的基础上优化了网络结构,用完全的卷积替代了全连接层,去掉池化层,并采用批标准化(Batch Normalization,BN)等技术,使得网络更容易训练。
[图片上传失败...(image-b4bd04-1595332746705)]
用DCGAN生成图像
为了更方便准确的说明DCGAN的关键环节,这里用一个简化版的模型实例来说明。代码基于pytorch深度学习框架,数据集采用MNIST
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import os
#定义一些超参数
nc = 1 #输入图像的通道数
nz = 100 #输入噪声的维度
num_epochs = 200 #迭代次数
batch_size = 64 #批量大小
sample_dir = 'gan_samples'
# 结果的保存目录
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# 加载MNIST数据集
trans = transforms.Compose([
transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
mnist = torchvision.datasets.MNIST(root=r'G:\VsCode\ml\mnist',
train=True,
transform=trans,
download=False)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
判别器&生成器
判别器使用LeakyReLU作为激活函数,最后经过Sigmoid输出,用于真假二分类
生成器使用ReLU作为激活函数,最后经过tanh将输出映射在[-1,1]之间
# 构建判别器
class Discriminator(nn.Module):
def __init__(self, in_channel=1, num_classes=1):
super(Discriminator, self).__init__()
self.conv = nn.Sequential(
# 28 -> 14
nn.Conv2d(nc, 512, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
# 14 -> 7
nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
# 7 -> 4
nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.AvgPool2d(4),
)
self.fc = nn.Sequential(
# reshape input, 128 -> 1
nn.Linear(128, 1),
nn.Sigmoid(),
)
def forward(self, x, label=None):
y_ = self.conv(x)
y_ = y_.view(y_.size(0), -1)
y_ = self.fc(y_)
return y_
# 构建生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc = nn.Sequential(
nn.Linear(nz, 4*4*512),
nn.ReLU(),
)
self.conv = nn.Sequential(
# input: 4 by 4, output: 7 by 7
nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
# input: 7 by 7, output: 14 by 14
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(),
# input: 14 by 14, output: 28 by 28
nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),
nn.Tanh(),
)
def forward(self, x, label=None):
x = x.view(x.size(0), -1)
y_ = self.fc(x)
y_ = y_.view(y_.size(0), 512, 4, 4)
y_ = self.conv(y_)
return y_
训练模型
# 使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = Discriminator().to(device)
G = Generator().to(device)
# 损失函数及优化器
criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1)
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(data_loader):
images = images.to(device)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
#————————————————————训练判别器——————————————————————
#鉴别真实样本
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
#鉴别生成样本
z = torch.randn(batch_size, nz).to(device)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
#计算梯度及更新
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()
#————————————————————训练生成器——————————————————————
z = torch.randn(batch_size, nz).to(device)
fake_images = G(z)
outputs = D(fake_images)
g_loss = criterion(outputs, real_labels)
#计算梯度及更新
reset_grad()
g_loss.backward()
g_optimizer.step()
if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# 保存生成图片
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# 保存模型
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
可视化结果
reconsPath = './gan_samples/fake_images-200.png'
Image = mpimg.imread(reconsPath)
plt.imshow(Image)
plt.axis('off')
plt.show()
[图片上传失败...(image-7731f-1595332746705)]
cGAN
在之前介绍的几种模型中,我们注意到生成器的输入都是一个随机的噪声。可以认为这个高维噪声向量提供了一些关键信息,而生成器根据自己的理解将这些信息进行补充,最终生成需要的图像。生成器生成图片的过程是完全随机的。例如上述的MNIST数据集,我们不能控制它生成的是哪一个数字。那么,有没有方法可以对其做一定的限制约束,来让生成器生成我们想要的结果呢?cGAN(Conditional Generative Adversarial Nets)通过增一个额外的向量y对生成器进行约束。以MNIST分类为例,限制信息y可以取10维的向量,对于类别进行one-hot编码,并与噪声进行拼接一起输入生成器。同样的,判别器也将原来的输入和y进行拼接。作者通过各种实验证明了这个简单的改进确实可以起到对生成器的约束作用。
[图片上传失败...(image-d88afd-1595332746705)]
判别器&生成器
只需要在前向传播的过程中加入限制变量y,我们很容易就能得到cGAN的生成器和判别器模型
class Discriminator(nn.Module):
...
def forward(self, x, label):
label = label.unsqueeze(2).unsqueeze(3)
label = label.repeat(1, 1, x.size(2), x.size(3))
x = torch.cat(tensors=(x, label), dim=1)
y_ = self.conv(x)
...
class Generator(nn.Module):
...
def forward(self, x, label):
x = x.unsqueeze(2).unsqueeze(3)
label = label.unsqueeze(2).unsqueeze(3)
x = torch.cat(tensors=(x, label), dim=1)
y_ = self.fc(x)
...
Pix2Pix
在上面的cGAN例子中,我们的控制信息取的是想要图像的标签,如果这个控制信息更加的丰富,例如输入一整张图像,那么它能否完成一些更加高级的任务?Pix2Pix(Image-to-Image Translation with Conditional Adversarial Networks)将这一类问题归纳为图像到图像的翻译,其使用改进后的U-net作为生成器,并设计了新颖的Patch-D判别器结构来输出高清的图像。Patch-D是指,不管网络所使用的输入图像有多大,都将其切割成若干个固定大小的Patch,判别器只需对这些Patch的真假进行判断。因为L1损失已经可以衡量生成图像和真实图像的全局差异,所以作者认为判别器只需要用Patch-D这样更关注于局部差异的结构即可。同时Patch-D的结构使得网络的输入变小,减少了计算量并且增大了框架的扩展性。
[图片上传失败...(image-79c5a2-1595332746705)]
CycleGAN
Pix2Pix虽然可以生成高清的图像,但其存在一个致命的缺点:需要相互配对的图片x与y。在现实生活中,这样成对的图片很难或者根本不可能搜集到,这就大大的限制了Pix2Pix的应用。对此,CycleGAN(Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks)提出了不需要配对的图像翻译方法。
[图片上传失败...(image-f9e728-1595332820293)]
CycleGAN其实就是一个X->Y的单向GAN上再加一个Y->X的单向GAN,构成一个“循环”。网络的结构和单次训练过程如下(图片来自于量子位):
[图片上传失败...(image-b8b5eb-1595332820293)]
[图片上传失败...(image-d2a355-1595332820293)]
除了经典的基础GAN损失之外,CycleGAN还引入了Consistency loss的概念。循环一致损失使得X->Y转变的过程中必须保留有X的部分特性。循环损失的公式如下:
两个判别器的损失表示如下:
最后网络的优化目标可以表示为
Pix2Pix以及CycleGAN的官方复现入口:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
StarGAN
Pix2Pix解决了有配对图像的翻译问题,CycleGAN解决了无配对图像的翻译问题,然而他们所作的图像到图像翻译,都是一对一。假设现在需要将人脸转换为喜怒哀乐四个表情,那么他们就需要进行4次不同的训练,这无疑会耗费巨大的计算资源。针对于这个问题,StarGAN(StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation)借助cGAN的思想,在网络输入中加入一个域的控制信息。对于判别器,其不仅需要鉴别样本是否真实,还需要判断输入的图像来自哪个域。StarGAN的训练过程如下:
- 将原始图片c和目标生成域c进行拼接后丢入生成器得到生成图像G(x,c)
- 将生成图像G(x,c)和真实图像y分别丢入判别器D,判别器除了需要判断输入图像的真伪之外,还需要判断它来自哪个域
- 将生成图像G(x,c)和原始生成域c'丢入生成器生成重构图片(为了对生成器生成的图像做进一步的限制,与CycleGAN的重构损失类似)
了解了StarGAN的训练过程,我们很容易得到其损失函数各项的表达形式
首先是GAN的一般损失,这里作者采用了前文所述的WGAN的损失形式:
对于判别器,我们需要鼓励其将输入图像正确的分类到目标域c‘(原始生成域):
对于生成器,我们需要鼓励其成功欺骗判别器将图片分类到目标域c(目标生成域),此外,生成器还需要在以生成图像和原始生成域c'的输入下成功将图像还原回去,这两部分的损失表示如下:
各部分损失乘上自己的权重加总后就构成了判别器和生成器的总损失:
此外,为了更具备通用性,作者还加入了mask vector来适应不同的数据集之间的训练。
总结
名称 | 创新点 |
---|---|
DCGAN | 首次将卷积神经网络引入GAN中 |
cGAN | 通过拼接标签信息来控制生成器的输出 |
Pix2Pix | 提出了一种图像到图像翻译的通用方法 |
CycleGAN | 解决了Pix2Pix需要图像配对的问题 |
StarGAN | 提出了一种一对多的图像到图像的翻译方法 |
InfoGAN | 基于cGAN改进,提出一种无监督的生成方法,适用于不知道图像标签的情况 |
LSGAN | 用最小二乘损失函数代替原始GAN的损失函数,缓解了训练不稳定、生成图像缺乏多样性的问题 |
ProGAN | 在训练期间逐步添加新的高分辨率层,可以生成高分辨率的图像 |
SAGAN | 将注意力机制引入GAN当中,简约高效利用了全局信息 |
本文列举了生成式对抗网络在发展过程中一些具有代表性的网络结构。GANs如今已广泛应用于图像生成、图像去噪、超分辨、文本到图像的翻译等各个领域,且在近几年的研究中涌现了很多优秀的论文。感兴趣的同学可以从下面的链接中pick自己想要了解的GAN~
- THE-GAN-ZOO:汇总了各种GAN的论文及代码地址。
- GAN Timeline:按照时间线对不同的GAN进行了排序。
- Browse state-of-the-art:将ArXiv上的最新论文与GitHub代码相关联,并做了比较排序,涉及了深度学习的各个方面。
参考文献
- Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.
- Arjovsky M, Chintala S, Bottou L. Wasserstein gan[J]. arXiv preprint arXiv:1701.07875, 2017.
- Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.
- Mirza M, Osindero S. Conditional generative adversarial nets[J]. arXiv preprint arXiv:1411.1784, 2014.
- Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 1125-1134.
- Zhu J Y, Park T, Isola P, et al. Unpaired image-to-image translation using cycle-consistent adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2223-2232.
- Choi Y, Choi M, Kim M, et al. Stargan: Unified generative adversarial networks for multi-domain image-to-image translation[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 8789-8797.
- Mao X, Li Q, Xie H, et al. Least squares generative adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2794-2802.
- Karras T, Aila T, Laine S, et al. Progressive growing of gans for improved quality, stability, and variation[J]. arXiv preprint arXiv:1710.10196, 2017.
- Chen X, Duan Y, Houthooft R, et al. Infogan: Interpretable representation learning by information maximizing generative adversarial nets[C]//Advances in neural information processing systems. 2016: 2172-2180.
- Zhang H, Goodfellow I, Metaxas D, et al. Self-attention generative adversarial networks[C]//International Conference on Machine Learning. 2019: 7354-7363.