生成器的本质就是一个神经网络,通过从一个简单的分布 z z z中采样来生成一个复杂的分布,同时要求这个生成的复杂分布尽可能地相似输入样本 x x x(真实数据)的分布。其实现了从低维数据到高维数据的转变。
那么生成器为什么输出的非要是分布呢?这是因为在实际训练中同一个输入,生成器可能会产生不同的输出,但是每一种输出单独出现的时候都是正确且合理的。比如下图是一个预判小精灵在下一秒走向的问题:如果我们的神经网络的训练数据集只有向右转的情况,那么输出就是向右转;如果我们的神经网络的训练数据集只有向左转的情况,那么输出就是向左转;如果我们的神经网络同时包含向右转和向左转的训练数据集,且向右转和向左转的误差都比较小,那么就会输出同时向左向右转的情况。而如果让机器的输出不再是一个单一的输出,而是一个机率的分布就可以解决这个问题。如果我们给神经网络增加了一个二分类分布(0,1各占50%),让0表示向右转,1表示向左转。
其实,这样做可以实现同样的输入有多种不同可能的输出。
GAN包含了两个模型:
算法流程:
初始化生成器和判别器 G 、 D G、D G、D。
在每一个训练迭代中
固定生成器 G G G,然后更新判别器 D D D**。**详细来说,就是首先我们从数据集中抽样出一部分样本,同时还会使用生成器 G G G生成相同数量的样本,分别将其标记为1、0。然后就是让判别器 D D D去学习这两类样本,使其能区分出真假样本。(这里只有判别器的梯度会更新)
固定判别器 D D D,然后更新生成器 G G G**。**详细来说,就是这里我们其实会想把生成器和判别器组成一个大的神经网络模型。然后从简单分布中采样出一个随机编码向量让它进入到生成其中去生成一个样本,然后将这个生成的样本喂入到判别器中,让其打分。最后根据这个分数来更新生成器的梯度。
总体流程如下图所示
如下图所示,生成器要做的事情就是输入一个低维的简单分布,输出一个高位的复杂分布 P G P_G PG,同时要做到生成的复杂分布要尽量与数据 d a t a data data的分布 P d a t a P_{data} Pdata越接近越好(即理想情况下,就是让 P G P_G PG和 P d a t a P_{data} Pdata一模一样)。
比如,我们现在输入是一个标准正态分布(其样本主要分布在中间),通过生成器 G G G之后,输出 P G P_G PG(样本分布在两端,且主要分布在左侧),其分布和目标分布 P d a t a P_{data} Pdata还是有一定差距的,那我们怎么让 P G P_G PG和 P d a t a P_{data} Pdata更接近呢?因为生成器本身就是一个神经网络,所以我们要做的就是设置一个合适的损失函数,根据我们的目标,我们可得:
G ∗ = a r g min G D i v ( P G , P d a t a ) ( 1 ) G^*=arg\min\limits_{G}Div(P_G,P_{data})\ \ \ (1) G∗=argGminDiv(PG,Pdata) (1)
其中, D i v Div Div表示的是分布 P G P_G PG和 P d a t a P_{data} Pdata之间的距离。而谈到距离,我们能想到的可能有KL divergence、JS divergence等,但是我们现在还有一个问题是其实我们并不知道 P G P_G PG和 P d a t a P_{data} Pdata真实的样子是怎样的?而解决这个问题的方法就是下文要提到的判别器 D D D。
如下图所示,我们首先假设蓝色星星是从真实样本分布中采样出来的样本,黄色星星是从模拟样本分布中采样出来的样本;目前两个样本是融合在一起的,而判别器要做的工作就是给真实的样本打一个较高的分数,给生成器模拟出来的样本打一个较低的分数,即最大化真实样本的分数,最小化模拟样本的分数。那么该判别器的损失函数为:
D ∗ = a r g max D V ( D , G ) V ( D , G ) = E y ∼ P d a t a [ l o g D ( y ) ] + E y ∼ P G [ l o g ( 1 − D ( y ) ) ] ( 2 ) \begin{aligned} D^*&=arg\max\limits_{D}V(D,G)\\ V(D,G)&=E_{y\sim P_{data}}[logD(y)]+E_{y\sim P_G}[log(1-D(y))] \end{aligned}\ \ \ (2) D∗V(D,G)=argDmaxV(D,G)=Ey∼Pdata[logD(y)]+Ey∼PG[log(1−D(y))] (2)
从这里我们还是看不出判别器损失函数 D ∗ D^* D∗和生成器损失函数 G ∗ G^* G∗之间的关系的,并且到这里似乎还看不出任何与距离相关的概念,其实只会再推导就能看出来了。
首先我们已知 V ( G , D ) = ∫ y ( P d a t a ( y ) log ( D ( y ) ) + P G ( y ) log ( 1 − D ( y ) ) ) d y V(G, D) = \int_y \bigg( P_{data}(y) \log(D(y)) + P_G (y) \log(1 - D(y)) \bigg) dy V(G,D)=∫y(Pdata(y)log(D(y))+PG(y)log(1−D(y)))dy,现在我们做一些简单的替换,假设 y ~ = D ( y ) , A = P d a t a ( y ) , B = P G ( y ) \tilde{y} = D(y), A=P_{data}(y), B=P_G(y) y~=D(y),A=Pdata(y),B=PG(y),则 V ( G , D ) = ∫ y ( A log y ~ + B log ( 1 − y ~ ) ) d y V(G, D) = \int_y \bigg( A\log\tilde{y} + B \log(1 - \tilde{y}) \bigg) dy V(G,D)=∫y(Alogy~+Blog(1−y~))dy,同时因为 y y y是从全体可能的样本中采样出来的,其实这是积分是可以省略的,那么就有了
f ( y ~ ) = A l o g y ~ + B l o g ( 1 − y ~ ) d f ( y ) ~ d y ~ = A 1 y ~ + B ⋅ 1 1 − y ~ ⋅ ( − 1 ) = A − ( A + B ) y ~ y ~ ( 1 − y ~ ) \begin{aligned} f(\tilde{y})&=Alog\tilde{y}+Blog(1-\tilde{y})\\ \frac{df(\tilde{y)}}{d\tilde{y}}&=A\frac{1}{\tilde{y}}+B\cdot\frac{1}{1-\tilde{y}}\cdot(-1)\\ &=\frac{A-(A+B)\tilde{y}}{\tilde{y}(1-\tilde{y})} \end{aligned} f(y~)dy~df(y)~=Alogy~+Blog(1−y~)=Ay~1+B⋅1−y~1⋅(−1)=y~(1−y~)A−(A+B)y~
然后,我们令 d f ( y ~ ) d y ~ = 0 \frac{df(\tilde{y})}{d\tilde{y}}=0 dy~df(y~)=0。对于上式,我们只能令 A − ( A + B ) y ~ = 0 A-(A+B)\tilde{y}=0 A−(A+B)y~=0,那么有
y ~ ∗ = D ∗ ( y ) = A A + B = P d a t a P d a t a + P G ∈ [ 0 , 1 ] \tilde{y}^*=D^*(y)=\frac{A}{A+B}=\frac{P_{data}}{P_{data}+P_G}\in[0,1] y~∗=D∗(y)=A+BA=Pdata+PGPdata∈[0,1]
当生成器足够好时,即生成器模拟生成的样本和数据样本完全相同 P d a t a = P G P_{data}=P_G Pdata=PG,有 D ∗ ( y ) = 1 2 D^*(y)=\frac{1}{2} D∗(y)=21。
因为 D ∗ ( y ) = P d a t a P d a t a + P G D^*(y)=\frac{P_{data}}{P_{data}+P_G} D∗(y)=Pdata+PGPdata,所以我们有
V ( G , D ∗ ) = ∫ y ( P d a t a ( y ) log ( D ∗ ( y ) ) + P G ( y ) log ( 1 − D ∗ ( y ) ) ) d y = ∫ y ( P d a t a ( y ) log ( P d a t a P d a t a + P G ) + P G ( y ) log ( 1 − P d a t a P d a t a + P G ) ) d y = ∫ y ( P d a t a ( y ) log ( P d a t a P d a t a + P G ) + P G ( y ) log ( P G P d a t a + P G ) ) d y \begin{aligned} V(G, D^*) &= \int_y \bigg( P_{data}(y) \log(D^*(y)) + P_G (y) \log(1 - D^*(y)) \bigg) dy\\ &=\int_y \bigg( P_{data}(y) \log(\frac{P_{data}}{P_{data}+P_G}) + P_G (y) \log(1 - \frac{P_{data}}{P_{data}+P_G}) \bigg) dy\\ &=\int_y \bigg( P_{data}(y) \log(\frac{P_{data}}{P_{data}+P_G}) + P_G (y) \log(\frac{P_{G}}{P_{data}+P_G}) \bigg) dy \end{aligned} V(G,D∗)=∫y(Pdata(y)log(D∗(y))+PG(y)log(1−D∗(y)))dy=∫y(Pdata(y)log(Pdata+PGPdata)+PG(y)log(1−Pdata+PGPdata))dy=∫y(Pdata(y)log(Pdata+PGPdata)+PG(y)log(Pdata+PGPG))dy
现在我们来计算一下分布 P d a t a P_{data} Pdata和 P G P_G PG之间的的JS距离:
D J S ( P d a t a ∣ ∣ P G ) = 1 2 D K L ( P d a t a ∣ ∣ P d a t a + P G 2 ) + 1 2 D K L ( P G ∣ ∣ P d a t a + P G 2 ) = 1 2 ∫ P d a t a l o g 2 P d a t a P d a t a + P G d x + 1 2 ∫ P G l o g 2 P G P d a t a + P G d x = 1 2 ( l o g 2 + ∫ P d a t a l o g P d a t a P d a t a + P G d x ) + 1 2 ( l o g 2 + ∫ P G l o g P G P d a t a + P G d x ) = 1 2 ( l o g 4 + ∫ y ( P d a t a ( y ) log ( P d a t a P d a t a + P G ) + P G ( y ) log ( P G P d a t a + P G ) ) d y ) = 1 2 ( l o g 4 + V ( G , D ∗ ) ) \begin{aligned} D_{JS}(P_{data}||P_G)&=\frac{1}{2}D_{KL}(P_{data}||\frac{P_{data}+P_G}{2})+\frac{1}{2}D_{KL}(P_G||\frac{P_{data}+P_G}{2})\\ &=\frac{1}{2}\int P_{data}log\frac{2P_{data}}{P_{data}+P_G}dx+\frac{1}{2}\int P_{G}log\frac{2P_{G}}{P_{data}+P_G}dx\\ &=\frac{1}{2}(log2+\int P_{data}log\frac{P_{data}}{P_{data}+P_G}dx)+\frac{1}{2}(log2+\int P_{G}log\frac{P_{G}}{P_{data}+P_G}dx)\\ &=\frac{1}{2}(log4+\int_y \bigg( P_{data}(y) \log(\frac{P_{data}}{P_{data}+P_G}) + P_G (y) \log(\frac{P_{G}}{P_{data}+P_G}) \bigg) dy)\\ &=\frac{1}{2}(log4+V(G,D^*)) \end{aligned} DJS(Pdata∣∣PG)=21DKL(Pdata∣∣2Pdata+PG)+21DKL(PG∣∣2Pdata+PG)=21∫PdatalogPdata+PG2Pdatadx+21∫PGlogPdata+PG2PGdx=21(log2+∫PdatalogPdata+PGPdatadx)+21(log2+∫PGlogPdata+PGPGdx)=21(log4+∫y(Pdata(y)log(Pdata+PGPdata)+PG(y)log(Pdata+PGPG))dy)=21(log4+V(G,D∗))
因此
V ( G , D ∗ ) = 2 D J S ( P d a t a ∥ P G ) − 2 log 2 V(G, D^*) = 2D_{JS}(P_{data} \| P_G) - 2\log2 V(G,D∗)=2DJS(Pdata∥PG)−2log2
从这里我们就能看出 V ( G , D ∗ ) V(G, D^*) V(G,D∗)是和JS距离有关系的,那生成器从一开始的目标不就是想要最小化分布 P d a t a P_{data} Pdata和 P G P_G PG之间的距离吗?所以我们可以直接用 max D V ( G , D ) \max\limits_{D}V(G, D) DmaxV(G,D)来替换 D i v ( P G , P d a t a ) Div(P_G,P_{data}) Div(PG,Pdata),那么最终生成器的损失函数为:
G ∗ = a r g min G max D V ( G , D ) ( 3 ) G^*=arg\min\limits_{G}\max\limits_{D}V(G, D)\ \ \ (3) G∗=argGminDmaxV(G,D) (3)
GAN的逻辑如下图所示:
在上一小节,我们已经得出了生成器与判别器各自的损失函数,求证出 V ( G , D ∗ ) V(G, D^*) V(G,D∗)是和JS距离有关系的。但是JS本身仍会存在着一些问题,如下图所示。
正如上图所示,但两个分布没有重叠时,JS距离总是为 l o g 2 log2 log2。这会导致虽然随着训练, P G P_G PG和 P d a t a P_{data} Pdata越来越接近(但是仍没有重合),而JS距离认为这个结果是一样坏的。为了解决这个问题,就有人提出了EM距离(推土机距离)。
EM距离之所以称为推土机距离,是因为它的原理就好比是把两个土堆推向一起,且其度量的就是这两个土堆的距离。例子如下图所示。
如上图所示,如果我们想让分布 P P P像分布 Q Q Q一样,那么我可以像上图左侧那样,但是上图右侧那样也是可以的。而这会带来很多种可能的“移动方案”,那么应该用哪一种方案合适呢?Wasserstein距离就将所有的”移动方案“穷举出来,然后将最小的”移动方案“作为W距离。
那么Wassertein距离有什么好处呢?如下图所示,与JS距离相比,W距离可以体现出每次分布的变化,随着两个分布距离的越来越近(没有重合),其值会越来越小。
而在WGAN中,分布 P d a t a P_{data} Pdata和 P G P_G PG的Wasserstein距离被定义为:
max D ∈ 1 − L i p s c h i t z { E y ∼ P d a t a [ D ( y ) ] − E y ∼ P G [ D ( y ) ] } \max\limits_{D\in 1-Lipschitz}\{E_{y\sim P_{data}}[D(y)]-E_{y\sim P_G}[D(y)]\} D∈1−Lipschitzmax{Ey∼Pdata[D(y)]−Ey∼PG[D(y)]}
第一眼看上去,这好像和GAN没什么区别似的,好像都是最大化真实分布的分数,最小化模拟分布的分数。但是仔细看的话,我们会注意到其条件发生了变化,上述过程要在 D ∈ 1 − L i p s c h i t z D\in1-Lipschitz D∈1−Lipschitz的条件下,那什么是Lipschitz
呢?简单来说,就是使 D D D的变换足够平滑,不能过于剧烈。
如上图所示,当没有约束时,判别器 D D D可能会使的真实数据的分数无限大,使得模拟数据的分数无效小,而这导致的一个直接问题就是判别器 D D D永远也无法收敛。但是如果有Lipschitz
条件存在,判别器 D D D就不会打出无限大或无限小的分数,两个分布之间的分数虽然存在差距,但不会很大,这样判别器 D D D最终仍会收敛。
而原始的WGAN虽然提出了上述目标函数,但是其处理手段非常粗糙。其主要是迫使梯度参数 W W W一直处于 [ − c , c ] [-c,c] [−c,c]之间,即如果 W > c , W = c W>c,W=c W>c,W=c,如果 W < − c , W = − c W<-c,W=-c W<−c,W=−c。
这一小节,同自编码器一样,都是对Fashion_MNIST数据集进行简单的处理。
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.autograd as autograd
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib
# 引入本地代码库
def to_img(x):
x = 0.5 * (x + 1)
x = x.clamp(0, 1)
x = x.view(x.size(0), 1, 28, 28)
return x
def imshow(img, filename=None):
npimg = img.numpy()
plt.axis('off')
array = np.transpose(npimg, (1, 2, 0))
if filename != None:
matplotlib.image.imsave(filename, array)
else:
plt.imshow(array)
plt.show()
img_transform = transforms.Compose([
transforms.ToTensor(),
# 使图片的值符合均值为0.5, 方差为0.5的正态分布
transforms.Normalize(mean=[0.5], std=[0.5])
])
data_dir = './fashion_mnist/'
train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True,
transform=img_transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(data_dir, train=False,
transform=img_transform)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
class WGAN_D(nn.Module):
def __init__(self, inputch=1):
super(WGAN_D, self).__init__()
# 先设置两个卷积层,对图片进行下采样
# 传入的图片shape为1024, 1, 28, 28
self.conv1 = nn.Sequential(
nn.Conv2d(inputch, 64, 4, 2, 1), # batch, 64, 14, 14
nn.LeakyReLU(0.2, True),
nn.InstanceNorm2d(64, affine=True))
self.conv2 = nn.Sequential(
nn.Conv2d(64, 128,4, 2, 1), # batch, 128, 7, 7
nn.LeakyReLU(0.2, True),
nn.InstanceNorm2d(128, affine=True) )
self.fc = nn.Sequential(
nn.Linear(128*7*7, 1024),
nn.LeakyReLU(0.2, True))
self.fc2 = nn.Sequential(
nn.InstanceNorm1d(1, affine=True), # 表明图片通道为1
nn.Flatten(),
nn.Linear(1024, 1))
def forward(self, x, * arg):
x = self.conv1(x)
x = self.conv2(x) # batch, 128, 7, 7
x = x.view(x.size(0), -1) # batch, 128*7*7
x = self.fc(x) # batch, 1024
x = x.reshape(x.size(0), 1, -1) # batch, 1, 1024
x = self.fc2(x) # batch * 1 * 1
# x.view(-1, 1).shape: batch, 1
# x.view(-1, 1).squeeze(1).shape: batch
return x.view(-1, 1).squeeze(1)
class WGAN_G(nn.Module):
def __init__(self, input_size, input_n=1):
super(WGAN_G, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size*input_n, 1024),
nn.ReLU(True),
nn.BatchNorm1d(1024)
)
self.fc2 = nn.Sequential(
nn.Linear(1024, 7*7*128),
nn.ReLU(True),
nn.BatchNorm1d(7*7*128)
)
# 使用全卷积还原到图片原来的大小
self.upsample1 = nn.Sequential(
nn.ConvTranspose2d(128, 64, 4, 2, padding=1, bias=False),
nn.ReLU(True),
nn.BatchNorm2d(64)
)
self.upsample2 = nn.Sequential(
nn.ConvTranspose2d(64, 1, 4, 2, padding=1, bias=False),
nn.Tanh()
)
def forward(self, x, *arg):
x = self.fc1(x)
x = self.fc2(x)
x = x.view(x.size(0), 128, 7, 7)
x = self.upsample1(x)
img = self.upsample2(x)
return img
第6和10行定义了两个卷积层,其使用的是窄卷积。窄卷积的原理是如果图片的维度是 N 1 × N 1 N_1\times N_1 N1×N1,卷积核的大小为 N 2 × N 2 N_2\times N_2 N2×N2,步长为 S S S,那么经过卷积后的图片维度为 N 1 − N 2 S + 1 × N 1 − N 2 S + 1 \frac{N_1-N_2}{S}+1\times \frac{N_1-N_2}{S}+1 SN1−N2+1×SN1−N2+1。以第6行的卷积层为例,图片大小为 28 × 28 28\times 28 28×28,卷积核大小为 4 × 4 4 \times 4 4×4,步长为2, padding为1,则经过这层卷积层之后的图片大小为 ( 28 + 2 ) − 4 2 + 1 = 14 \frac{(28+2)-4}{2}+1=14 2(28+2)−4+1=14,即 14 × 14 14 \times 14 14×14。
第48和53行定义了两个卷积层,其实用的是全卷积。全卷积的原理是如果图片的维度是 N 1 × N 1 N_1\times N_1 N1×N1,卷积核的大小为 N 2 × N 2 N_2\times N_2 N2×N2,步长为 S S S,padding为n,那么经过卷积后的图片维度是 ( N 1 + 2 n ) × S − N 2 (N_1+2n)\times S -N_2 (N1+2n)×S−N2 。举个例子,如果图片大小为 7 × 7 7\times 7 7×7,卷积核大小为 4 × 4 4 \times 4 4×4,步长为2, padding为1,则经过这层卷积层之后的图片大小为 ( 7 + 2 × 1 ) × 2 − 4 = 14 (7+2\times 1) \times 2 -4=14 (7+2×1)×2−4=14,即 14 × 14 14 \times 14 14×14。
# Loss weight for gradient penalty
lambda_gp = 10
def compute_gradient_penalty(D, real_samples, fake_samples, y_one_hot):
eps = torch.FloatTensor(real_samples.size(0), 1, 1, 1).uniform_(0, 1).to(device)
X_inter = (eps * real_samples +((1 - eps) * fake_samples)).requires_grad_(True)
d_interpolates = D(X_inter, y_one_hot)
fake = torch.full((real_samples.size(0),), 1, device=device)
# d_interpolates 对 X_inter求偏导
gradients = autograd.grad( outputs=d_interpolates,
inputs=X_inter,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean()*lambda_gp
return gradient_penaltys
torch.autograd.grad()
函数简介:outputs
参数是被求导的参数,inputs
是求导的参数,当outputs
是一个标量时grad_outputs=None
;当outputs
是一个向量时,需要为grad_outputs
指定一个值。
import torch
x = torch.randn(3,4).requires_grad_(True)
y = torch.sum(x)
gradient = torch.autograd.grad(y, x)[0]
print(gradient)
'''
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
'''
x_1 = torch.randn(4,).requires_grad_(True)
y_1 = torch.sum(x_1)
gradient1 = torch.autograd.grad(y_1, x_1)[0]
print(gradient1)
'''
tensor([1., 1., 1., 1.])
'''
在代码的第15行我们定义了一个长度为4的向量 x = [ x 1 , x 2 , x 3 , x 4 ] x=[x_1, x_2, x_3,x_4] x=[x1,x2,x3,x4],第16行 y = x 1 + x 2 + x 3 + x 4 y=x_1+x_2+x_3+x_4 y=x1+x2+x3+x4,使用torch.autograd.grad()
的过程为:
δ y δ x 1 = δ ( x 1 + x 2 + x 3 + x 4 ) x 1 = 1 δ y δ x 2 = δ ( x 1 + x 2 + x 3 + x 4 ) x 2 = 1 δ y δ x 3 = δ ( x 1 + x 2 + x 3 + x 4 ) x 3 = 1 δ y δ x 4 = δ ( x 1 + x 2 + x 3 + x 4 ) x 4 = 1 \begin{aligned} \frac{\delta y}{\delta x_1}=\frac{\delta(x_1+x_2+x_3+x_4)}{x_1}=1\\ \frac{\delta y}{\delta x_2}=\frac{\delta(x_1+x_2+x_3+x_4)}{x_2}=1\\ \frac{\delta y}{\delta x_3}=\frac{\delta(x_1+x_2+x_3+x_4)}{x_3}=1\\ \frac{\delta y}{\delta x_4}=\frac{\delta(x_1+x_2+x_3+x_4)}{x_4}=1 \end{aligned} δx1δy=x1δ(x1+x2+x3+x4)=1δx2δy=x2δ(x1+x2+x3+x4)=1δx3δy=x3δ(x1+x2+x3+x4)=1δx4δy=x4δ(x1+x2+x3+x4)=1
def train(D, G, outdir, z_dimension, num_epochs=30):
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
os.makedirs(outdir, exist_ok=True)
for epoch in range(num_epochs):
for i, (img, lab) in enumerate(train_loader):
num_img = img.size(0)
# =================train discriminator
real_img = img.to(device) # 1024 * 1 * 28 * 28
y_one_hot = torch.zeros(lab.shape[0], 10).scatter_(1,
lab.view(lab.shape[0],1),1).to(device)
for ii in range(5): # 循环5次
d_optimizer.zero_grad()
real_out = D(real_img,y_one_hot) # shape为[batch]
z = torch.randn(num_img, z_dimension).to(device)
fake_img = G(z, y_one_hot) # shape为[batch, 1, 28, 28]
fake_out = D(fake_img, y_one_hot)
gradient_penalty = compute_gradient_penalty(D,
real_img.data, fake_img.data,y_one_hot)
d_loss = -torch.mean(real_out) + torch.mean(fake_out) + gradient_penalty
d_loss.backward()
d_optimizer.step()
# ===============train generator
# compute loss of fake_img
for ii in range(1):
g_optimizer.zero_grad()
z = torch.randn(num_img, z_dimension).to(device)
fake_img = G(z, y_one_hot)
fake_out = D(fake_img, y_one_hot)
g_loss = - torch.mean(fake_out)
g_loss.backward()
g_optimizer.step()
fake_images = to_img(fake_img.cpu().data)
real_images = to_img(real_img.cpu().data)
rel = torch.cat([to_img(real_images[:10]),fake_images[:10]],axis = 0)
imshow(torchvision.utils.make_grid(rel,nrow=10),
os.path.join(outdir, 'fake_images-{}.png'.format(epoch+1) ) )
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'
.format(epoch, num_epochs, d_loss.data, g_loss.data,
real_out.data.mean(), fake_out.data.mean()))
torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth' ) )
torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth' ) )
def displayAndTest(D,G,z_dimension):
# 可视化结果
sample = iter(test_loader)
images, labels = sample.next()
y_one_hot = torch.zeros(labels.shape[0],10).scatter_(1,
labels.view(labels.shape[0],1),1).to(device)
num_img = images.size(0)
with torch.no_grad():
z = torch.randn(num_img, z_dimension).to(device)
fake_img = G(z,y_one_hot)
fake_images = to_img(fake_img.cpu().data)
rel = torch.cat([to_img(images[:10]),fake_images[:10]],axis = 0)
imshow(torchvision.utils.make_grid(rel,nrow=10))
print(labels[:10])
if __name__ == '__main__':
z_dimension = 40 # noise dimension
D = WGAN_D().to(device) # discriminator model
G = WGAN_G(z_dimension).to(device) # generator model
train(D,G,'./w_img',z_dimension)
displayAndTest(D,G,z_dimension)