GAN造图MINST手写数字

文章目录

  • 1.生成对抗网络介绍
  • 2.代码讲解
    • 2.1 导入库文件、设定维度和批次
    • 2.2 载入数据集
    • 2.3 定义鉴别网络
    • 2.4 生成网络
    • 2.5 定义网络实例、优化器和损失函数
    • 2.6 网络训练过程
    • 2.7 保存模型
    • 2.8 训练结果图

1.生成对抗网络介绍

生成对抗网络GAN(Generate Adversarial Network)包含两个部分,生成网络和对抗网络。

  • 生成网络(generator):给定一个简单的高维的正态分布的噪声向量,通过生成网络的WX+b和卷积池化等操作将其变为一个与输入的真实图片大小相同的矩阵,所生成的图片输入对抗网络判断真假,尽可能提高它所生成图像被判定为真的概率。
  • 鉴别网络(discriminator):就是一个实现二分类(01分类,仅输出一个数值为0~1的数)的网络,可用CNN,RNN或DNN实现(本文介绍CNN和DNN实现)。鉴别网络要尽可能地提高它分辨图像真假的能力。
    GAN造图MINST手写数字_第1张图片

GAN的工作过程如上图所示。在训练过程中,每个epoch中轮流对鉴别网络和生成网络进行训练。

  • 首先训练鉴别网络。1)将真实图片输入鉴别网络,输出真实图像分数。鉴别网络需要训练其分数尽可能地接近1(真),计算其损失d_loss_real。2)然后随机生成向量,输入生成网络,输出与真实图像大小相同的矩阵,将其输入鉴别网络中,输出虚假图像分数。鉴别网络需要训练其分数尽可能接近0(假),计算d_loss_fake
  • 然后训练生成网络。同样是随机生成向量,输入生成网络,输出与真实图像大小相同的矩阵,将其输入鉴别网络中,输出虚假图像分数。生成网络需要训练虚假图片分数尽可能接近1,计算其损失g_loss

2.代码讲解

本代码在pytorch框架下实现GAN网络进行MINST手写数据库造图。

2.1 导入库文件、设定维度和批次

导入库文件

import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms as T
from torchvision.utils import save_image
import os

设定维度和批次

# 定义一个生成网络生成图片的保存地址
if not os.path.exists('./img'):
    os.mkdir('./img')
   
def to_img(x):
    out = 0.5*(x+1)   # 在读取图像时使用transform进行了归一化,这里进行反归一化 
    out = out.clamp(0,1)   # 限制out的值在0-1
    out = out.view(-1,1,28,28)  # 展平为图片矩阵形状
    return out

batch_size = 64   # 设定一个批次有64张图片
num_epoch = 100   # 训练周期为100
z_dimension = 100  # 正态分布噪声向量维度

2.2 载入数据集

# 定义加载时的变换,转化为张量并归一化
img_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5,), std=(0.5,))
])
# GAN目标是造假图,不区分训练集测试集
dataset = datasets.MNIST(root='G:/dl_dataset/',transform=img_transform)
dataloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)

2.3 定义鉴别网络

DNN版本

# 判别器为普通二分类网络
# 输入batchsize*784,输出batchsize
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784,256),   # x在输入时已经展平,每张图片变为一个784维向量
            nn.LeakyReLU(0.2),    # 激活层
            nn.Linear(256,256),   # 全连接层
            nn.LeakyReLU(0.2),    # 激活层
            nn.Linear(256,1),     # 全连接层
            nn.Sigmoid()          # 激活层,将得分变为0-1的数
        )
    def forward(self, x):
        x = self.dis(x) 
        return x

CNN版本

# 输入 batchsize*1*28*28 输出 batchsize
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),  # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),  # batch, 32, 14, 14
            )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2)  # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64*7*7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)           # block1包括卷积层、激活、平均池化
        x = self.conv2(x)           # block2包括卷积层、激活、平均池化
        x = x.view(x.size(0), -1)   # 展平为向量
        x = self.fc(x)              # 通过全连接层、激活层、全连接层、池化层最后输出0-1的分数
        return x

2.4 生成网络

DNN版本

# 输入 batchsize*z_dimensions 输出 batchsize*784
class generator(nn.Module):
    def __init__(self,in_dim):
        super(generator,self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(True),
            nn.Linear(256,256),
            nn.ReLU(True),
            nn.Linear(256,784),
            nn.Tanh())
        
    def forward(self, x):
        x = self.gen(x)            # 输出batchsize*784维向量
        return x

CNN版本

# 输入 batchsize*z_dimension  输出 batchsize*784
class generator(nn.Module):
    def __init__(self, in_dim):
        super(generator, self).__init__()
        self.fc = nn.Linear(in_dim, 3136)  # batch, 3136=1x56x56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)      # batch regulation 正则化
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        x = x.view(x.size(0),-1)
        return x      # 输出batchsize*784

2.5 定义网络实例、优化器和损失函数

D = discriminator()          # 鉴别网络
G = generator(z_dimension)   # 生成网络
if torch.cuda.is_available():
    D, G = D.cuda(), G.cuda()  # 如果gpu可用将模型放到gpu上
criterion = nn.BCELoss()  # 二分类的交叉熵
d_optimizer = optim.Adam(D.parameters(),lr=0.0003)  # 鉴别网络的优化器
g_optimizer = optim.Adam(G.parameters(),lr=0.0003)  # 生成网络的优化器

2.6 网络训练过程

# 训练周期为num_epoch
for epoch in range(num_epoch):

    print('*'*30)
    print('epoch{}'.format(epoch+1))
    
    for i,(img,_) in enumerate(dataloader):
    # 读入一次dataloader结构为 img,label,这里读入图片全为真实图片
        num_img = img.size(0)   # batch_size
        # -----train discriminator
        img = img.view(num_img,-1)  # 拉成 num_img,784   若使用CNN网络,这句省略
        real_img = img.cuda()    # 将图片放入gpu
        real_label = torch.ones(num_img).cuda()  # 真实图片标签放入gpu
        fake_label = torch.zeros(num_img).cuda() # 虚假图片标签放入gpu
        
        # compute loss of real img
        real_out = D(real_img)     # batchsize张真实图片的分数
        d_loss_real = criterion(real_out.squeeze(-1),real_label)  # 计算鉴别网络真实图片的损失
        real_scores = real_out # 1代表真,0代表假 越接近1越好
        
        # compute loss of fake img
        z = torch.randn(num_img,z_dimension).cuda()  #生成batchsize个随机向量
        fake_img = G(z)    # 随机向量输入生成网络输出虚假图像
        fake_out = D(fake_img)  # 虚假图像输入鉴别网络输出虚假图片分数
        d_loss_fake = criterion(fake_out.squeeze(-1),fake_label)  # 计算鉴别网络虚假图片的损失
        fake_scores = fake_out  # 越接近0越好
        
        d_loss = d_loss_real + d_loss_fake   # 计算鉴别网络的总损失
        d_optimizer.zero_grad()  # 梯度清0
        d_loss.backward()    # 反向传播计算梯度
        d_optimizer.step()    # 优化器前进
        
        # ---------train generator 
        z = torch.randn(num_img,z_dimension).cuda()  #生成随机向量
        fake_img = G(z)    # 生成虚假图片
        output = D(fake_img)     # 输入鉴别网络计算虚假图片分数
        g_loss = criterion(output.squeeze(-1),real_label)   # 得到假的图片与真实图片label的loss,要以假乱真
        
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
    if epoch == 0:  # 保存一张真实图像
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')
    print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch+1, num_epoch, d_loss.item(), g_loss.item(),
                      real_scores.data.mean(), fake_scores.data.mean()))
    fake_images = to_img(fake_img.cpu().data)  # 保存虚假图像
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

2.7 保存模型

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

2.8 训练结果图

epoch10
GAN造图MINST手写数字_第2张图片
epoch30
GAN造图MINST手写数字_第3张图片
epoch50
GAN造图MINST手写数字_第4张图片
epoch70
GAN造图MINST手写数字_第5张图片
epoch100
GAN造图MINST手写数字_第6张图片
真实图像
GAN造图MINST手写数字_第7张图片

可以看出所生成的虚假图像越来越像真实图像。

你可能感兴趣的:(深度学习笔记)