【pytorch基础笔记三】基础GAN基于MINST生成手写数字

【参考资料】
【1】《python深度学习:基于PyTorch》
【2】《深入浅出GAN生成对抗网络》

 基于pyrotch的GAN手写数字生成方案很简单,但实际参考网上资料编写到调试通还是花费一定时间 :(

GAN网络基本思路如下:

生成器: 多层神经网络,最后一层为全连接,从某一个随机序列生成一组噪声图片
判别器: 多层神经网络,最后一层为sigmoid, 判断是否为真实图片

启动对抗过程为往复下面步骤1和2
步骤1:训练判别器
1.1 用当前的生成器生成一批伪造图片
1.2 从数据集中取一批真实图片
1.3 利用上述两批数据训练判别器
1.4 此步骤的目的用于使得判别器区分生成图片和真实图片,前者期望结果为1,后者期望结果为0
1.5 反向传导优化判别器参数,注意只优化判别器
步骤2:训练生成器
2.1 用当前的生成器构造一张图片
2.2 利用【步骤1】的判别器对2.1 的图片进行判断,此时需要让判断器判断其成功,即期望结果为1
2.3 利用 2.2 打出的结果作为误差训练生成器,注意此时只优化生成器

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image


"""
模型常量定义
"""

image_size       = 28 #图像像素28
num_epochs       = 100  #循环次数
train_batch_size = 128 #mini-batch训练数量

# 生成器相关参数
g_input_dim  = 100
g_hidden_1   = 256
g_hidden_2   = 256
g_hidden_3   = 256
g_output_dim = image_size**2

#判别器模型相关参数
d_input_dim  = image_size**2
d_hidden_1   = 256
d_hidden_2   = 256
d_hidden_3   = 256
d_output_dim = 1

def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  
    out = out.view(-1, 1, 28, 28)  
    return out

"""
生成器模型

nn.BatchNorm1d: 一维的归一化
备注:会报错 batchsize 是因为数据没有按照批次进,不符合mini-batch

"""
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),  
            nn.ReLU(True),  
            nn.Linear(256, 256),  
            nn.ReLU(True),  
            nn.Linear(256, 784),  
            nn.Tanh()  
        )
 
    def forward(self, x):
        x = self.gen(x)
        return x    

"""
判别器模型 
""" 
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),  
            nn.LeakyReLU(0.2),  
            nn.Linear(256, 256), 
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  
        )

    def forward(self, x):
        x = self.dis(x)
        return x
# 实例化模型、损失函数等
d_learning_rate = 3e-4  # 3e-4
g_learning_rate = 3e-4
optim_betas     = (0.9, 0.999)
criterion       = nn.BCELoss()  #损失函数 - 二进制交叉熵
G = Generator()
D = Discriminator()
g_optimizer = optim.Adam(G.parameters(), lr=d_learning_rate)
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate)

#数据集构造
transform     = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
data_loader   = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

def _show_traindata(data): 
    img_data = data[0][0].detach().numpy()
    img_data = (img_data - -1)*125
    plt.imshow(img_data,cmap='gray')
     
for epoch in range(num_epochs): 
    for index, (imgs, _) in enumerate(data_loader):
        #步骤1:训练判别器
        #将真实图片放入判别器,训练判别器认为真实图片为 1
        train_batch_size = imgs.size(0)
        imgs = imgs.view(train_batch_size, -1)  
        d_real_decision = D(Variable(imgs))
        d_real_label    = Variable(torch.ones(train_batch_size))
        d_real_error    = criterion(d_real_decision, d_real_label)
        
        
        #将fake图片放入判别器,训练判别器认为真实图片为0
        d_fake_input = Variable(torch.randn(train_batch_size, g_input_dim))
        d_fake_imgs = G(d_fake_input).detach()
        d_fake_decision = D(d_fake_imgs)
        d_fake_label = Variable(torch.zeros(train_batch_size))

        
        d_fake_error = criterion(d_fake_decision, d_fake_label)
        d_loss = d_fake_error + d_real_error
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step() 
        
        
        #步骤2:训练生成器
        g_fake_input    = Variable(torch.randn(train_batch_size, g_input_dim))
        g_fake_imgs     = G(g_fake_input)
        g_fake_decision = D(g_fake_imgs)
        g_fake_label    = Variable(torch.ones(train_batch_size))
        g_fake_error    = criterion(g_fake_decision, g_fake_label)
        g_optimizer.zero_grad()
        g_fake_error.backward()
        g_optimizer.step() 
        
        if (index + 1) % 100 == 0:  
            real_images = to_img(g_fake_imgs.cpu().data)
            save_image(real_images, './img/test.png')
            
    print("Epoch[{}],d_loss:{:.6f}".format(epoch,d_loss.data.item()))

训练结果:

【pytorch基础笔记三】基础GAN基于MINST生成手写数字_第1张图片

你可能感兴趣的:(机器学习笔记,机器学习)