pytorch实现简单GAN

1.什么是GAN(Generative Adversarial Networks)

2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文。没错,我说的就是《Generative Adversarial Nets》,这标志着生成对抗网络(GAN)的诞生,而这是通过对计算图和博弈论的创新性结合。他们的研究展示,给定充分的建模能力,两个博弈模型能够通过简单的反向传播(backpropagation)来协同训练。

这两个模型的角色定位十分鲜明。给定真实数据集 DATA,G 是生成器(generator),它的任务是生成能以假乱真的假数据;而 D 是判别器 (discriminator),它从真实数据集或者 G 那里获取数据, 然后做出判别真假的标记。Ian Goodfellow 的比喻是,G 就像一个赝品作坊,想要让做出来的东西尽可能接近真品,蒙混过关。而 D 就是文物鉴定专家,要能区分出真品和高仿(但在这个例子中,造假者 G 看不到原始数据,而只有 D 的鉴定结果——前者是在盲干)。

pytorch实现简单GAN_第1张图片

pytorch实现简单GAN_第2张图片

理想情况下,D 和 G 都会随着不断训练,做得越来越好——直到 G 基本上成为了一个“赝品制造大师”,而 D 因无法正确区分两种数据分布输给 G。

2.数学建模

设真实数据的概率分布为Pdata, 生成器生成数据的概率分布为Pg

规定D的输出代表输入为“真”的概率,则D的目标是:

若输入是真品,则提高D(x)

若输入是赝品,则降低D(x)

综合起来用数学语言描述如下

对于G来说,它的目标是尽可能提高生成数据被D判别为“真”的概率,用数学语言描述如下:

 也即

3.全局最优解

对D的目标函数经过推导可以得到最优的D*如下:

pytorch实现简单GAN_第3张图片

把D*代入G的目标可以推出:

pytorch实现简单GAN_第4张图片

 取等号。由此得出生成器生成数据的分布在最优解情况下就等于真实数据的分布

4.用Pytorch实现简单GAN

首先明确我们的目标——让G学会画sin x的图像

import numpy as np
import torch.nn as nn
import torch
import matplotlib.pyplot as plt


LR = 0.0001
BATCH_SIZE = 64
DATA_SIZE = 16
IDEA = 5
X = np.linspace(0, 2 * np.pi, DATA_SIZE)


def p_data(x):
    f = np.zeros((BATCH_SIZE, DATA_SIZE))
    for i in range(BATCH_SIZE):
        f[i] = np.sin(x)
    return f


G = nn.Sequential(
    nn.Linear(IDEA, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, DATA_SIZE)
)

D = nn.Sequential(
    nn.Linear(DATA_SIZE, 64),
    nn.ReLU(),
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
    nn.Sigmoid()
)

D_optimizer = torch.optim.Adam(D.parameters(), lr=LR)
G_optimizer = torch.optim.Adam(G.parameters(), lr=LR)

接下来训练网络

for step in range(10000):

    real = torch.tensor(p_data(X)).float()
    idea = torch.randn((BATCH_SIZE, IDEA))
    fake = G(idea)

    prob_real = D(real)
    prob_fake = D(fake)

    D_loss = -torch.mean((torch.log(prob_real) + torch.log(torch.tensor(1) - prob_fake)))
    G_loss = torch.mean(torch.log(torch.tensor(1) - prob_fake))

    D_optimizer.zero_grad()
    D_loss.backward(retain_graph=True)
    D_optimizer.step()

    G_optimizer.zero_grad()
    G_loss.backward()
    G_optimizer.step()

    if step % 100 == 0:
        print(prob_real.mean())
        print(prob_fake.mean())
        print('-----------------------------------------------')
    if torch.abs(prob_real.mean() - 0.5) <= 1.e-6:
        break
    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(X, fake.data.numpy()[0], c='red', lw=3, label='Generated painting')
        plt.plot(X, real.data.numpy()[0], c='black', lw=1, label='real painting')
        plt.text(1, .5, 'the prob of Generated painting is real = %.2f' % prob_fake.data.numpy().mean())
        plt.ylim((-1.1, 1.1))
        plt.legend(loc='best', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()
plt.show()

torch.save(G.state_dict(), './G_state_dict.pkl')

训练结果如下 

pytorch实现简单GAN_第5张图片

 

你可能感兴趣的:(pytorch,机器学习,神经网络,深度学习)