生成对抗网络GAN

一、生成对抗网络的结构图

生成对抗网络GAN_第1张图片
GAN分为两部分,包括生成器和判别器,先训练判别器,看做是一个二元分类器,可以判别是生成的图像还是真实图像。当判别器训练到一定阶段,freeze固定住其参数,开始训练生成器。使用生成器生成fake图像,使得判别器可以误认为是real数据(loss很大),训练目标就是让生成器越来越能生成像real的图像。当达到比较好的效果时,再freeze生成器的参数,开始训练判别器,提高判别器的判别能力。以此交替进行,最后模型可以生成很接近real的图像了。

二、GAN_LOSS

生成对抗网络GAN_第2张图片

例子: 让网络可以生成在x平方 ~ 2*x^2 + 1 之间的数据。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np


torch.manual_seed(1)
np.random.seed(1)

LR_G = 0.0001
LR_D = 0.0001
BATCH_SIZE = 64
N_IDEAS = 5

ART_COMPONETS = 15
# 竖着堆叠 把[-1,1]均匀切割
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONETS) for _ in range(BATCH_SIZE)])

# plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')    #2 * x^2 + 1
# plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')    #   x^2
# plt.legend(loc='upper right')           #标签位置
# plt.show()


# 获取一个batch的标准区间数据
def artist_work():
    # 为了随机出系数和bias,从一个均匀分布[low,high)中随机采样  (np.newaxis是增加一个维度)
    a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
    paints = a * np.power(PAINT_POINTS,2) + (a-1)
    paints = torch.from_numpy(paints).float()
    return paints


G = nn.Sequential(
    nn.Linear(N_IDEAS,128),
    nn.ReLU(),
    nn.Linear(128,ART_COMPONETS)
)
D = nn.Sequential(
    nn.Linear(ART_COMPONETS,128),
    nn.ReLU(),
    nn.Linear(128,1),
    nn.Sigmoid()
)

optimizer_G = torch.optim.Adam(G.parameters(),lr=LR_G)
optimizer_D = torch.optim.Adam(D.parameters(),lr=LR_D)

# 打开交互模式
plt.ion()

for step in range(10000):

    # 获取一个batch的标准区间数据
    artist_painting = artist_work()

    # 从[-1,1]中随机出一个batch的input,inputsize=N_IDEAS
    G_idea = torch.randn(BATCH_SIZE,N_IDEAS)

    # 经过model最终维度变为ART_COMPONETS(15),得到一个batch的生成数据。
    G_paintings = G(G_idea)

    # 判别器分别评测两种数据,得到对两种数据评测的打分。
    pro_atrist0 = D(artist_painting)
    pro_atrist1 = D(G_paintings)

    # G生成的数据经D判别结果越接近1,loss越低。这是对于G优化的方向。
    G_loss = -torch.mean(torch.log(pro_atrist1))
    # G生成的数据经D判别越接近0,标准数据经D判别越接近1,loss越低。这是D的优化方向。
    D_loss = -torch.mean(torch.log(pro_atrist0)+torch.log(1-pro_atrist1))

    optimizer_G.zero_grad()
    G_loss.backward(retain_graph=True)
    # 不可以在这里更新G网络,因为下面D_loss.backward()的时候,需要用到pro_atrist1,这个量的计算需要利用G_paintings,而G_paintings依赖于G网络。
    # 如果G网络提前更新了,就会导致在进行D_loss.backward()的时候报错。
    # optimizer_G.step()

    optimizer_D.zero_grad()
    D_loss.backward()
    optimizer_G.step()
    optimizer_D.step()



    if step % 200 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % pro_atrist0.data.numpy().mean(), fontdict={'size': 13})

        # plt.text(-.5, 2, 'G_loss= %.2f ' % G_loss.data.numpy(), fontdict={'size': 13})
        # print('D accuracy=%.2f'%pro_atrist0.data.numpy().mean())
        # print('G_loss= %.2f ' % G_loss.data.numpy())



        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)

plt.ioff()
plt.show()

你可能感兴趣的:(pytorch,深度学习,生成对抗网络,深度学习,计算机视觉)