论文传送门:https://arxiv.org/pdf/1701.07875.pdf
参考文章:令人拍案叫绝的Wasserstein GAN - 知乎
WGAN的目的:解决GAN的梯度不稳定、多样性不足的问题。
WGAN的思想:使用Wasserstein距离代替JS散度,来描述生成分布与真实分布的距离。
WGAN的实现:与GAN相比,有四处不同:
①判别器D去掉最后一层sigmoid激活函数,使得判别器D的作用变为计算近似的Wasserstein距离(代码13-31行);
②生成器和判别器的loss不再取log,近似Wasserstein距离(代码131,140行);
③判别器参数更新时,限制其值在[-c,c]区间内,使D(x)满足Lipschitz连续条件(代码134-135行);
④不采用基于动量的优化器(Adam等),使用RMSprop或SGD(代码113-114行)。
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
class Discriminator(nn.Module): # 定义判别器(WS-divergence)
def __init__(self, img_shape=(1, 28, 28)): # 初始化方法
super(Discriminator, self).__init__() # 继承初始化方法
self.img_shape = img_shape # 图片形状
self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512) # linear映射
self.linear2 = nn.Linear(512, 256) # linear映射
self.linear3 = nn.Linear(256, 1) # linear映射
self.leakyrelu = nn.LeakyReLU(0.2, inplace=True) # leakyrelu激活函数
def forward(self, x): # 前传函数
x = torch.flatten(x, 1) # 输入图片从三维压缩至一维特征向量,(n,1,28,28)-->(n,784)
x = self.linear1(x) # linear映射,(n,784)-->(n,512)
x = self.leakyrelu(x) # leakyrelu激活函数
x = self.linear2(x) # linear映射,(n,512)-->(n,256)
x = self.leakyrelu(x) # leakyrelu激活函数
x = self.linear3(x) # linear映射,(n,256)-->(n,1)
return x # 返回近似拟合的Wasserstein距离
class Generator(nn.Module): # 定义生成器
def __init__(self, img_shape=(1, 28, 28), latent_dim=100): # 初始化方法
super(Generator, self).__init__()
self.img_shape = img_shape # 图片形状
self.latent_dim = latent_dim # 噪声z的长度
self.linear1 = nn.Linear(self.latent_dim, 128) # linear映射
self.linear2 = nn.Linear(128, 256) # linear映射
self.bn2 = nn.BatchNorm1d(256, 0.8) # bn操作
self.linear3 = nn.Linear(256, 512) # linear映射
self.bn3 = nn.BatchNorm1d(512, 0.8) # bn操作
self.linear4 = nn.Linear(512, 1024) # linear映射
self.bn4 = nn.BatchNorm1d(1024, 0.8) # bn操作
self.linear5 = nn.Linear(1024, self.img_shape[0] * self.img_shape[1] * self.img_shape[2]) # linear映射
self.leakyrelu = nn.LeakyReLU(0.2, inplace=True) # leakyrelu激活函数
self.tanh = nn.Tanh() # tanh激活函数,将输出压缩至(-1.1)
def forward(self, z): # 前传函数
z = self.linear1(z) # linear映射,(n,100)-->(n,128)
z = self.leakyrelu(z) # leakyrelu激活函数
z = self.linear2(z) # linear映射,(n,128)-->(n,256)
z = self.bn2(z) # 一维bn操作
z = self.leakyrelu(z) # leakyrelu激活函数
z = self.linear3(z) # linear映射,(n,256)-->(n,512)
z = self.bn3(z) # 一维bn操作
z = self.leakyrelu(z) # leakyrelu激活函数
z = self.linear4(z) # linear映射,(n,512)-->(n,1024)
z = self.bn4(z) # 一维bn操作
z = self.leakyrelu(z) # leakyrelu激活函数
z = self.linear5(z) # linear映射,(n,1024)-->(n,784)
z = self.tanh(z) # tanh激活函数
z = z.view(-1, self.img_shape[0], self.img_shape[1], self.img_shape[2]) # 从一维特征向量扩展至三维图片,(n,784)-->(n,1,28,28)
return z # 返回生成的图片
if __name__ == "__main__":
# 训练参数
total_epochs = 100 # 训练轮次
batch_size = 64 # 批大小
lr = 5e-5 # 学习率
num_workers = 8 # 数据加载线程数
latent_dim = 100 # 噪声z长度
image_size = 28 # 图片尺寸
channel = 1 # 图片通道
clip_value = 0.01 # 判别器参数限定范围
dataset_dir = "dataset/mnist" # 训练数据集路径
gen_images_dir = "gen_images" # 生成样例图片路径
cuda = True if torch.cuda.is_available() else False # 设置是否使用cuda
os.makedirs(dataset_dir, exist_ok=True) # 创建训练数据集路径
os.makedirs(gen_images_dir, exist_ok=True) # 创建样例图片路径
image_shape = (channel, image_size, image_size) # 图片形状
# 模型
D = Discriminator(image_shape) # 实例化判别器
G = Generator(image_shape, latent_dim) # 实例化生成器
if cuda: # 如果使用cuda
D = D.cuda() # 模型加载到GPU
G = G.cuda() # 模型加载到GPU
# 数据集
transform = transforms.Compose( # 数据预处理方法
[transforms.Resize(image_size), # resize
transforms.ToTensor(), # 转为tensor
transforms.Normalize([0.5], [0.5])] # 标准化
)
dataloader = DataLoader( # dataloader
dataset=datasets.MNIST( # 数据集选取MNIST手写体数据集
root=dataset_dir, # 数据集存放路径
train=True, # 使用训练集
download=True, # 自动下载
transform=transform # 应用数据预处理方法
),
batch_size=batch_size, # 设置batch size
num_workers=num_workers, # 设置读取数据线程数
shuffle=True # 设置打乱数据
)
# 优化器
optimizer_D = torch.optim.RMSprop(D.parameters(), lr=lr) # 定义判别网络RMSprop优化器,传入学习率
optimizer_G = torch.optim.RMSprop(G.parameters(), lr=lr) # 定义生成网络RMSprop优化器,传入学习率
# 训练循环
for epoch in range(total_epochs): # 循环epoch
pbar = tqdm(total=len(dataloader), desc=f'Epoch {epoch + 1}/{total_epochs}', postfix=dict,
mininterval=0.3) # 设置当前epoch显示进度
for i, (real_imgs, _) in enumerate(dataloader): # 循环iter
if cuda: # 如果使用cuda
real_imgs = real_imgs.cuda() # 数据加载到GPU
bs = real_imgs.shape[0] # batchsize
# 开始训练判别网络D
optimizer_D.zero_grad() # 判别网络D清零梯度
z = torch.randn((bs, latent_dim)) # 生成输入噪声z,服从标准正态分布,长度为latent_dim
if cuda: # 如果使用cuda
z = z.cuda() # 噪声z加载到GPU
fake_imgs = G(z).detach() # 噪声z输入生成网络G,得到生成图片,并阻止其反向梯度传播
loss_D = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs)) # 判别网络D的损失函数
loss_D.backward() # 反向传播,计算当前梯度
optimizer_D.step() # 根据梯度,更新网络参数
for p in D.parameters(): # 遍历判别网络D的模型参数
p.data.clamp_(-clip_value, clip_value) # 将参数限制在[-clip_value,clip_value]区间
# 开始训练生成网络G
optimizer_G.zero_grad() # 生成网络G清零梯度
gen_imgs = G(z) # 噪声z输入生成网络G,得到生成图片
loss_G = -torch.mean(D(gen_imgs)) # 生成网络G的损失函数
loss_G.backward() # 反向传播,计算当前梯度
optimizer_G.step() # 根据梯度,更新网络参数
pbar.set_postfix(**{'D_loss': loss_D.item(), 'G_loss': loss_G.item()}) # 显示判别网络D和生成网络G的损失
pbar.update(1) # 步进长度
pbar.close() # 关闭当前epoch显示进度
save_image(gen_imgs.data[:25], "%s/ep%d.png" % (gen_images_dir, (epoch + 1)), nrow=5,
normalize=True) # 保存生成图片样例(5x5)