原文:Generative Adversarial Networks (arxiv.org)
以前的一些模型是通过显示训练数据的分布,比如上面的VAE。也就是可以得到最终训练好的均值与方差,而GAN是隐式训练。
GAN受博弈论中的零和博弈启发,将生成问题视作判别器和生成器这两个网络的对抗和博弈:生成器从给定噪声中(一般是指均匀分布或者正态分布)产生合成数据,判别器分辨生成器的的输出和真实数据。前者试图产生更接近真实的数据,后者试图更完美地分辨真实数据与生成数据。由此,两个网络在对抗中进步,在进步后继续对抗,由生成网络得到的数据也就越来越完美,逼近真实数据,从而可以生成想要得到的数据(图片、序列、视频等)。
图1 GAN模型体系结构图示
设真实训练数据的分布服从 x ∼ p data x \sim p_\text{data} x∼pdata,先验噪声输入变量分布服从 z ∼ q z ( z ) z \sim q_z(z) z∼qz(z), 生成模型为 G ( z ; θ g ) G(z;\theta_g) G(z;θg),其中 θ g \theta_g θg是生成模型中的网络参数。设判别模型为 D ( z ; θ d ) D(z;\theta_d) D(z;θd),其中 θ d \theta_d θd是判别模型中的网络参数,生成模型通过学习数据 x x x得到生成模型分布 p g p_g pg。根据理论分析结果,如果给予生成模型足够的承载力和足够的训练时间,生成模型 G G G完全有可能拟合出类似于真实样本分布 p data p_\text{data} pdata的从假设随机分布 p z p_z pz得到的合成样本分布 p g p_g pg。 同时,当模型训练达到全局最优平衡点 p g = p data p_g=p_\text{data} pg=pdata的时候,说明生成对抗网络达到了 博弈的最佳状态。数学优化模型如下所示:
min G max D V ( D , G ) = E x ∼ p data ( x ) [ log ( D ( x ) ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( x ) ) ] \min_{\ G} \max_{\ D} V(D, G) = {E}_{x\sim p_{\text{data}}(x)}[\log(D(x))] +{E}_{z\sim p_z(z)}[\log(1 - D(G(x))] Gmin DmaxV(D,G)=Ex∼pdata(x)[log(D(x))]+Ez∼pz(z)[log(1−D(G(x))]
min G E z ∼ p z ( z ) [ log ( 1 − D ( G ( x ) ) ] \min_{\ G}E_{z\sim p_z(z)}[\log(1 - D(G(x))] GminEz∼pz(z)[log(1−D(G(x))]
max D E x ∼ p data ( x ) [ log ( D ( x ) ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( x ) ) ] \max_{\ D}E_{x\sim p_{\text{data}}(x)}[\log(D(x))] + E_{z\sim p_z(z)}[\log(1 - D(G(x))] DmaxEx∼pdata(x)[log(D(x))]+Ez∼pz(z)[log(1−D(G(x))]
由于判别网络和生成网络得到的极值点可能并不是在同一个位置,那么便存在一个全局的平衡点,也就是纳什(Nash)均衡点。
由上式可知,给定任意生成器 G G G,判别器 D D D 的训练标准都是最大化 V ( D , G ) V(D, G) V(D,G),那么, 第一条公式还可以表示为:
V ( D , G ) = ∫ x P data ( x ) log ( D ( x ) ) d x + ∫ x p z ( z ) log ( 1 − D ( g ( z ) ) ) d z = ∫ x [ P data ( x ) log ( D ( x ) ) + P g ( x ) log ( 1 − D ( x ) ) ] d x \begin{aligned} V(D, G) & =\int_x P_{\text {data }}(x) \log (D(x)) d x+\int_x p_z(z) \log (1-D(g(z))) d z \\ & =\int_x\left[P_{\text {data }}(x) \log (D(x))+P_g(x) \log (1-D(x))\right] d x \end{aligned} V(D,G)=∫xPdata (x)log(D(x))dx+∫xpz(z)log(1−D(g(z)))dz=∫x[Pdata (x)log(D(x))+Pg(x)log(1−D(x))]dx
已知形如 alog ( y ) + b log ( 1 − y ) \operatorname{alog}(y)+b \log (1-y) alog(y)+blog(1−y) 的函数, 当 y ∈ [ 0 , 1 ] y \in[0,1] y∈[0,1] 时, a , b a, b a,b 为非零实数时, 可以在 a a + b \frac{a}{a+b} a+ba 处取得最小值,所以当生成器网络的参数固定不变训练判别器时,判别器网络的当前全局最优值可以通过下面公式获得:
D G ∗ ( x ) = P data ( x ) P data ( x ) + P g ( x ) D_G^*(x)=\frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_g(x)} DG∗(x)=Pdata (x)+Pg(x)Pdata (x)
判别器最优时训练函数表示如下:
max D V ( D , G ) = E x ∼ P data ( x ) [ log ( D G ∗ ( x ) ) ] + E z ∼ P z ( z ) [ log ( 1 − D G ∗ ( G ( z ) ) ) ] = E x ∼ P data ( x ) [ log ( D G ∗ ( x ) ) ] + E x ∼ P g ( x ) [ log ( 1 − D G ∗ ( x ) ) ] = E x ∼ P data ( x ) [ log P data ( x ) P data ( x ) + P g ( x ) ] + E x ∼ P g ( x ) [ log P data ( x ) P data ( x ) + P g ( x ) ] \begin{aligned} \max _D V(D, G) & =E_{x \sim P_{\text {data }}(x)}\left[\log \left(D_G^*(x)\right)\right]+E_{z \sim P_z(z)}\left[\log \left(1-D_G^*(G(z))\right)\right] \\ & =E_{x \sim P_{\text {data }}(x)}\left[\log \left(D_G^*(x)\right)\right]+E_{x \sim P_g(x)}\left[\log \left(1-D_G^*(x)\right)\right] \\ & =E_{x \sim P_{\text {data }}(x)}\left[\log \frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_g(x)}\right]+E_{x \sim P_g(x)}\left[\log \frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_g(x)}\right] \end{aligned} DmaxV(D,G)=Ex∼Pdata (x)[log(DG∗(x))]+Ez∼Pz(z)[log(1−DG∗(G(z)))]=Ex∼Pdata (x)[log(DG∗(x))]+Ex∼Pg(x)[log(1−DG∗(x))]=Ex∼Pdata (x)[logPdata (x)+Pg(x)Pdata (x)]+Ex∼Pg(x)[logPdata (x)+Pg(x)Pdata (x)]
当且仅当 p g = p data p_g=p_\text{data} pg=pdata时,此时的真实样本数 据与生成数据的分布相同, V ( D , G ) V(D, G) V(D,G)取得最小值 − l o g 4 -log4 −log4,这时生成器就能生成无法同真实样本区分的“真实”数据,判别器已经无法再区分它们之间的差别。
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
# 创建保存生成图片的目录
os.makedirs("images", exist_ok=True)
# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="训练的轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每批的样本数量")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的梯度一阶矩的衰减系数")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的梯度二阶矩的衰减系数")
parser.add_argument("--n_cpu", type=int, default=8, help="用于生成批次的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--img_size", type=int, default=28, help="每个图像维度的大小")
parser.add_argument("--channels", type=int, default=1, help="图像的通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="生成图像样本的间隔")
opt = parser.parse_args()
print(opt)
# 定义图像形状
img_shape = (opt.channels, opt.img_size, opt.img_size)
# 检查是否支持CUDA加速
cuda = True if torch.cuda.is_available() else False
# 定义生成器模型
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *img_shape)
return img
# 定义鉴别器模型
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# 定义对抗网络的损失函数
adversarial_loss = torch.nn.BCELoss()
# 初始化生成器和鉴别器
generator = Generator()
discriminator = Discriminator()
# 如果支持CUDA,则将模型移至GPU
if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
# 定义优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# 定义张量类型
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# ----------
# 训练
# ----------
for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# 定义对抗网络的真实标签和生成标签
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
# 配置输入
real_imgs = Variable(imgs.type(Tensor))
# -----------------
# 训练生成器
# -----------------
optimizer_G.zero_grad()
# 生成器输入的噪声样本
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# 生成一批图像
gen_imgs = generator(z)
# 生成器的损失,度量其欺骗鉴别器的能力
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# ---------------------
# 训练鉴别器
# ---------------------
optimizer_D.zero_grad()
# 鉴别器的损失,度量其区分真实图像和生成图像的能力
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)
# 每隔一定间隔保存生成的图像样本
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
save_image(gen_imgs.data[:25