PyTorch - GAN与WGAN及其实战

目录

  • GAN
    • 基本结构
    • 训练
      • 对于生成器
      • 对于判别器
      • 训练流程
      • 训练理论
        • min max公式
        • Where will D converge, given fixed G
        • Where will G converge, after optimal D
  • GAN的实现
    • 导入包以及设置参数变量
    • Generator
    • Discriminator
    • data_generator
    • 训练过程
    • 结果分析
  • WGAN
    • GAN的缺点
    • WGAN原理
  • WGAN的实现
    • gradient_penalty
    • 训练过程的变化部分

GAN

基本结构

PyTorch - GAN与WGAN及其实战_第1张图片

训练

对于生成器

  • 输入一个n维度向量
  • 输出图片像素大小的图片

对于判别器

  • 输入图片
  • 输出图片的真伪标签

训练流程

PS:LaTex语法大全

  1. 初始化判别器D的参数 θ d θ_{d} θd和生成器G的参数 θ g θ_{g} θg
  2. 从真实样本中采样 m m m个样本 { x 1 , x 2 , . . . , x m x_{1}, x_{2}, ..., x_{m} x1,x2,...,xm } ,从先验分布噪声中采样 m m m个噪声样本 { z 1 , z 2 , . . . , z m z_{1}, z_{2}, ..., z_{m} z1,z2,...,zm } ,并通过生成器获取 m m m个生成样本 { x ~ 1 , x ~ 2 , . . . , x ~ m \widetilde x_{1}, \widetilde x_{2}, ..., \widetilde x_{m} x 1,x 2,...,x m }
  3. 循环k次更新判别器之后,使用较小的学习率来更新一次生成器的参数
  4. 多次更新迭代之后,最终理想情况是使得判别器判别不出样本来自于生成器的输出还是真实的输出,即最终样本判别概率均为0.5

训练理论

min max公式

概率论期望部分补充
在这里插入图片描述

Where will D converge, given fixed G

证明参考
详细证明(1):GAN(生成式对抗网络)的变量代换步骤
证明(1)中的 单位冲激函数

PyTorch - GAN与WGAN及其实战_第2张图片

Where will G converge, after optimal D

KL散度
PyTorch - GAN与WGAN及其实战_第3张图片

GAN的实现

该实战中不使用图片数据集,而是自己生成高斯混合分布作为数据。

导入包以及设置参数变量

import torch
from torch import nn, optim
import numpy as np
import random
from    matplotlib import pyplot as plt

h_dim = 400
batchsz = 512

Generator

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        # 4层
        self.net = nn.Sequential(
            # z: [b, 2] => xg: [b, 2]  (这个"2"对应的维度包含了x和y的坐标)
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
        )


    def forward(self, z):
        output = self.net(z)
        return output

Discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            # 转化为概率
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

data_generator

np.random.randn(shape) 用法
yield关键字用法

def data_generator():
    # 8-gaussian mixture models

	# 放大倍数
    scale = 2
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    # 放大一下
    centers = [(scale * x, scale * y) for x, y in centers]

    while True:
        dataset = []

        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)

            # N(0, 1) + center_x1/x2
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)

        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset

训练过程

torch.manual_seed(序号) 用法
np.random.seed(序号) 用法

optim.Adam()中的betas参数:
betas[0]对应下图的μ;betas[1]对应下图的ρ
PyTorch - GAN与WGAN及其实战_第4张图片

def main():

    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()

    G = Generator().cuda()
    D = Discriminator().cuda()

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    for epoch in range(50000):

        # train Discriminator first
        for _ in range(5):
            # train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cuda()

            # [b, 2] => [b ,1]
            predr = D(xr)

            # maximize predr => minimize lossr   (看mixmax公式就能推出来要不要加负号)
            lossr = - predr.mean()


            # train on fake data
            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            # 不让梯度传到Generator,因为只优化Discriminator
            xf = G(z).detach()
            predf = D(xf)
            lossf = predf.mean()

            # aggregate all
            loss_D = lossr + lossf

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)   # 这个位置不能加detach()
        loss_G = - predf.mean()

        # optimization
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:

            print(loss_D.item(), loss_G.item())

if __name__ == '__main__':
    main()

结果分析

代码中不包含可视化部分,故直接截图视频中的可视化结果
PyTorch - GAN与WGAN及其实战_第5张图片
PyTorch - GAN与WGAN及其实战_第6张图片
可以看到 loss 快速收敛,很快就没有梯度了。所以生成的分布也没法继续靠近原数据分布。

WGAN

GAN的缺点

难训练,不稳定。生成器和判别器之间需要很好的同步,但是在实际训练中很容易D收敛,G发散。

WGAN原理

  • 在训练 D 的 loss 中加入一项 gradient penalty
  • 通过插值法找到 x x x x ~ \widetilde x x 之间的中间态,并计算其梯度
  • 将其梯度的二范数约束在 1 附近
    PyTorch - GAN与WGAN及其实战_第7张图片

WGAN的实现

WGAN 在实现上与上边的 GAN 只是在 D 的训练 Loss 中加入了 gradient penalty 项

gradient_penalty

torch.rand(size) 用法
expand_as(Tensor) 用法
requires_grad_() 用法
torch.autograd.grad(outputs, inputs) 用法
torch.ones_like(Tensor) 用法

注意:为什么要使用 autograd.grad() 来计算?
因为PyTorch的自动求导是针对 weights 的求导,所以我们要手动设置这一求导过程。

def gradient_penalty(D, xr, xf):
    """

    :param D:
    :param xr: [b, 2]
    :param xf:  [b, 2]
    """
    t = torch.rand(batchsz, 1).cuda()
    # [b, 1] => [b, 2]
    t = t.expand_as(xr)

    mid = t * xr + (1 - t) * xf
    # set it requires gradient
    mid.requires_grad_()

    pred = D(mid)
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)

    # 二范数和1越接近越好
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()

    return gp

训练过程的变化部分

调用写好的函数,求出gp,并加入到 D 总的 Loss 中,为了平衡各部分 Loss 的尺度,乘上一个参数。

            # gradient penalty
            gp = gradient_penalty(D, xr, xf.detach())

            # aggregate all
            loss_D = lossr + lossf + 0.2 * gp

你可能感兴趣的:(深度学习,神经网络,生成对抗网络,pytorch)