深度学习-GAN生成式对抗网络

生成式对抗网络(GAN,generative adversarial network)的简单理解就是,想想一名伪造者试图伪造一幅毕加索的画作。一开始,伪造者非常不擅长这项任务,他随便画了幅与毕加索真迹放在一起,请鉴定商进行评估,鉴定商鉴定后,将结果反馈给伪造者,并告诉他怎样可以让❀看起来更像毕加索的真迹。伪造者学习后回去重新画,然后再拿给鉴定商鉴定,多次循环后,伪造者已经十分熟练的伪造毕加索的画作了,鉴定商的鉴定能力也有了很大的提高。最后,他们手上拥有了一些优秀的毕加索赝品。

下边的例子是使用GAN模拟二次方程:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

"""超参数"""
BATCH_SIZE = 64  # 每批数据个数
LR_G = 0.0001  # 生成器的学习率(伪造者)
LR_D = 0.0001  # 判别器的学习率(鉴定商)
N_IDEAS = 5  # 随机想法个数
ART_COMPONENTS = 15  # 线段上数据点个数
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])  # h恩坐标范围

"""创建圣经网络"""
def artist_works():
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)  # 定义原始一元二次方程(一开始的真品)
    paintings = torch.from_numpy(paintings).float()  # 转换为torch形式
    return paintings
G = nn.Sequential(  # 生成器(伪造者)
    nn.Linear(N_IDEAS, 128),  # 输入随即想法
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS)  # 生成15个点连线(创造一个赝品)
)
D = nn.Sequential(  # 判别器(鉴定上)
    nn.Linear(ART_COMPONENTS, 128),  # 接受生成器生成的数据(获得赝品)
    nn.ReLU(),
    nn.Linear(128, 1),  # 判别是否和原始数据相似(鉴定赝品是真是假)
    nn.Sigmoid()  # 产生百分比,表示是什么数据(表示是真品还是赝品)
)
opt_D = torch.optim.RMSprop(D.parameters(), lr=LR_D)  # 优化判别器
opt_G = torch.optim.RMSprop(G.parameters(), lr=LR_G)  # 优化生成器

"""训练神经网络"""
plt.ion()
for step in range(5000):
    artist_paintings = artist_works()  # 先获取原始标准方程(一开始的真品)
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # 随机生成数据(想法)
    G_paintings = G(G_ideas)  # 生成器产生方程(创造赝品)

    prob_artist0 = D(artist_paintings)  # 计算式标准方程的概率(真品的概率)
    prob_artist1 = D(G_paintings)  # 计算式伪造方程的概率(赝品的概率)

    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1 - prob_artist1))  # 增加标准方程的概率
    G_loss = torch.mean(torch.log(1 - prob_artist1))  # 增加伪造方程被认为是真方程的概率

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    """循环打印"""
    if step % 100 == 0:  # 每100步打印一次
        plt.cla()  # 清空上一次的
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3,
                 label='Generated painting', )  # 随机生成的方程
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3,
                 label='upper bound')  # 标准方程上界
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3,
                 label='lower bound')  # 标准方程下界
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()
plt.show()

深度学习-GAN生成式对抗网络_第1张图片

你可能感兴趣的:(深度学习,python,PyTorch,深度学习,GAN)