GAN生成对抗网络以及pytorch实现

                **GAN生成对抗网络以及pytorch实现**

首先GAN近些年来在图像生成等相关方面有很多的应用,我就刚开始看GAN说下自己的理解。
GAN中分为两个模块,一个是生成器Generator(以下用G表示),判别器Discriminator(以下用D表示),这两个我们都可以看成一个神经网络。
例如我们针对Mnist数据集,如果,图像的可以的维度是2828,那么G就是从一个向量生成一个2828的矩阵,这个矩阵如果使用线性网络就是的得到一个1784的向量,比如向量是1100的随机向量,(这个向量可以通过随机生成得到,也就是噪声)我们通过Linear得到一个1784然后reshape得到2828的图片。判别器D就是从这个图片中得到一个分数,我们理想的目标是判别器对真实图片识别得到的分数越高越好,即更接近1,对生成器G生成的图片得分越低越好,更接近0。如下图所示的大概过程(中间的hidden layer其实就是生成器G的输出):

GAN生成对抗网络以及pytorch实现_第1张图片
那么知道了整体的网络结构,下面介绍GAN的前向传播和反向传播过程:
GAN生成对抗网络以及pytorch实现_第2张图片

  1. 选取m个真实的图片样本
  2. 选取m个噪声样本,每一个噪声样本为一个向量,向量的长度可以自己确定,分布可以任选一个分布,比如正态分布等。
  3. 将噪声输入生成器G中,m个样本会得到m个图片,如果是线性就是m*784的向量。
  4. 然后对判别器进行训练,我们要得到的结果就是判别器对真实图片的得分更高,对虚假的图片得分更低,也就是真实图片更接近1,虚假图片更接近0,在具体代码中,可以计算对真实图片的得分和1的交叉熵loss_real,对虚假图片的得分和0交叉熵loss_fake,那么判别器的总损失为两个相加,并且两个都是越小越好。
  5. 然后进行生成器的训练和更新,在生成器,要是生成器生成的图片更接近真实的图片,也就还是D(G(z))的值越接近于1,生成器就越好,那么生成器的损失就是g_loss为D(G(z))和1的交叉熵,那么也就是越小越好。
    使用pytorch进行实现的代码如下:
# GAN的核心就是类似一个造假器
# 拥有一个Generator和Discriminator,G和N都是一个neural network
# G即使生成器,生成器可以从一个随机的向量中得到一个图片,
# D即判别器,判别器就是判断从G中生成的图片是否为真实的图片,图片会存在分类,真实的图片为1,G生成的图片为0即为假
# 1.初始化G和N
# 2.in each training iteration:
#    (1).首先固定住G的参数,然后去更新D的参数,目的是为了使D中识别出G的图片分数更低,然后对真实的图片分数更高。
#    (2).然后固定住D,更新G的参数,在更新参数的过程中,可以将G和N当作一整个网络,例如G和N都是五层,那么网络就是10层, 在10层中间,会有一个hidden layer,即使一个图片的向量,
# 例如64*64的向量,那么也就是说整个网络输入的是一个向量,然后在这个过程中
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
import os
from torchvision.utils import save_image
import matplotlib.pyplot as plt
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081, ))
    ])
train_dataset = datasets.MNIST(
        root='../dataset/mnist',
        train=True,
        download=False,
        transform=transform
        )
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True
)
test_dataset = datasets.MNIST(
    root='../dataset/mnist',
    train=False,
    download=False,
    transform=transform
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False
)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784)
        )
        # self.l1 = torch.nn.Linear(784, )
        # self.l2 = torch.nn.Linear()
        # self.l3 = torch.nn.Linear()
    def forward(self, x): #x为一个张量 最终得到一个图片的张量
        x = x.reshape(x.size(0), -1)
        x = self.layer1(x)
        return x

class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        x = self.layer1(x)
        return x

G = Generator()
D = Discriminator()

criterion = nn.BCELoss()
d_optomizer = optim.Adam(D.parameters(), lr=0.0003)
g_optomizer = optim.Adam(G.parameters(), lr=0.0003)

d_loss_train = []
g_loss_train = []
x = []
def train(epoch):
    for i, (input, target) in enumerate(train_loader):
        num_img = input.size(0)
        real_label = torch.ones(num_img)
        fake_label = torch.zeros(num_img)



        #训练D,判别器

        real_label = real_label.reshape(real_label.size(0), -1)
        fake_label = fake_label.reshape(fake_label.size(0), -1)

        real_out = D(input)
        d_loss_real = criterion(real_out, real_label)
        real_score = real_out

        z = torch.randn((num_img, 100))
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)

        d_loss = d_loss_fake + d_loss_real
        d_optomizer.zero_grad()
        d_loss.backward()
        d_optomizer.step()

        #训练生成器
        z = torch.randn(num_img, 100)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        g_optomizer.zero_grad()
        g_loss.backward()
        g_optomizer.step()


        g_loss_train.append(g_loss.item() / num_img)
        d_loss_train.append(d_loss.item() / num_img)
        x.append(i)



        if i % 5 == 0:

            print(g_loss.item() / num_img, d_loss.item() / num_img)


train(1)
torch.save(G, 'generator_test1.pkl')
torch.save(D, 'discriminator_test1.pkl')
plt.plot(x, d_loss_train)
plt.plot(x, g_loss_train)
plt.show()











你可能感兴趣的:(pytorch,深度学习,generator)