pytorch gan minist

原理省略

需要两个文件
gan.py
module_load.py

gan.py–训练代码

##argparse是python用于解析命令行参数和选项的标准模块
#使用步骤:
#1 import argparse
#2 parser=argparse.ArgumentParser()
#3 parser.add_argument()
#4 parser.parse_args()
import argparse
import os
import numpy as np
import math

#用于data augmentation
import torchvision.transforms as transforms
#保存生成图像
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch


##############################################################################
#设置模型中的某些参数,这里使用了argparse来集中操作

#如果根目录下不存在images文件夹,则创建images存放生成图像结果
os.makedirs('images',exist_ok=True)

#创建解析对象
parser=argparse.ArgumentParser()
#向解析对象中添加命令行参数和选项
#epoch = 20,屁大小=64,学习率=0.0002,衰减率=0.5/0.999,线程数=8,隐码维数=100,样本尺寸=28*28,通道数=1,样本间隔=400
parser.add_argument("--n_epochs",type=int,default=1,help="number of  epochs of training")
parser.add_argument("--batch_size",type=int,default=64,help="size of batches")
parser.add_argument("--lr",type=float,default=0.0002,help="adam:learning rate")
parser.add_argument("--b1",type=float,default=0.5,help="adam:decay of first order momentum of gradient")
parser.add_argument("--b2",type=float,default=0.999,help="adam:decay of first order momentum of gradient")
parser.add_argument("--n_cpu",type=int,default=8,help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim",type=int,default=100,help="dimensionality of the latent space")
parser.add_argument("--img_size",type=int,default=28,help="size of each image dimension")
parser.add_argument("--channels",type=int,default=1,help="number of image channels")
parser.add_argument("--sample_interval",type=int,default=400,help="interval between image samples")
#解析参数
opt=parser.parse_args()
print(opt)

img_shape=(opt.channels,opt.img_size,opt.img_size)# 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda=True if torch.cuda.is_available() else False

#################################################################################
#创建生成器G

#---------------------
#     生成器
#---------------------
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()

        def block(in_feat,out_feat,normalize=True):
            #这里简单的只对输入数据做线性转换
            layers=[nn.Linear(in_feat,out_feat)]
            #使用BN
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            #添加LeakyReLUE非线性激活层
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers

        #创建生成器网络模型
        self.model = nn.Sequential(
            *block(opt.latent_dim,128,normalize=False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,int(np.prod(img_shape))),
            nn.Tanh()
        )

    #前进
    def forward(self,z):
        #生成假样本
        img=self.model(z)
        img=img.view(img.size(0),*img_shape)
        #返回生成图像
        return img


###################################################################################
#创建判别器D
#-------------------
#     判别器
#-------------------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()

        self.model=nn.Sequential(
            nn.Linear(int(np.prod(img_shape)),512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(512,256),
            nn.Linear(256,1),
            #因需判别真假,这里使用Sigmoid函数给出标量的判别结果
            nn.Sigmoid(),
        )

    #判别
    def forward(self,img):
        img_flat=img.view(img.size(0),-1)
        validity=self.model(img_flat)
        #判别结果
        return validity


###################################################################################
#损失函数和优化器
adversarial_loss=torch.nn.BCELoss()

#初始化生成器和鉴别器
generator=Generator()
discriminator=Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

#优化器,G和D都使用adam
optimizer_G=torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

###################################################################################
#加载数据器
os.makedirs("../../data/mnist",exist_ok=True)
#--------------------------------------
#      torch.utils.data.DataLoader
#--------------------------------------
#数据加载器,结合了数据集和取样器,并且可以提供多个线性处理数据集。
#在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。
#直至把所有的数据都抛出。就是做一个数据的初始化
#
#torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
#                            batch_sampler=None, num_workers=0, collate_fn=,
#                            pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
# dataset:加载数据的数据集
#batch_size:每批次加载的数据量
#shuffle:默认false,若为True,表示在每个epoch打乱数据
#sampler:定义从数据集中绘制示例的策略,如果指定,shuffle必须为false
#...
#更多可参考: https://pytorch.org/docs/stable/data.html

#设置数据加载器,这里使用MNIST数据集
dataloader=torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

if __name__ == '__main__':
    ##############################################################################
    #训练模型
    for epoch in range(opt.n_epochs):
        for i, (imgs,_) in enumerate(dataloader):
            valid=Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad=False)
            fake=Variable(Tensor(imgs.size(0),1).fill_(0.0),requires_grad=False)

            #输入
            real_imgs=Variable(imgs.type(Tensor))

            #------------------------
            #训练G
            #------------------------

            optimizer_G.zero_grad()

            #采样随机噪声问题
            z=Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))

            #训练得到一批次生成样本
            gen_imgs=generator(z)

            #计算G的损失函数值
            g_loss=adversarial_loss(discriminator(gen_imgs),valid)

            #更新G
            g_loss.backward()
            optimizer_G.step()

            #---------------------
            #     训练D
            #---------------------

            optimizer_D.zero_grad()

            #评估D的判别能力
            real_loss=adversarial_loss(discriminator(real_imgs),valid)
            fake_loss=adversarial_loss(discriminator(gen_imgs.detach()),fake)
            d_loss=(real_loss+fake_loss)/2

            #更新D
            d_loss.backward()
            optimizer_D.step()

            if (i+1)%10==0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss:%f] [G loss: %f]"
                    %(epoch, opt.n_epochs,i, len(dataloader),d_loss.item(),g_loss.item())
                )

            #保存结果
            batches_done=epoch*len(dataloader)+i
            if batches_done % opt.sample_interval ==0:
                save_image(gen_imgs.data[:25],"images/%d.png" %batches_done,nrow=5,normalize=True)
            if (epoch+1)%1==0 and i+1==len(dataloader):
                print('save__')
                torch.save(generator,'g%d.pth'%epoch)
                torch.save(discriminator,'d%d.pth'%epoch)

module_load.py—导入最后保存生成器的模型

from gan import Generator, Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image


device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
Tensor=torch.FloatTensor
g=torch.load('g0.pth')#导入生成器Generator模型
#d=torch.load('d.pth')
g=g.to(device)
#d=d.to(device)

z = Variable(Tensor(np.random.normal(0, 1, (64, 100))))  #输入的噪音
gen_imgs=g(z)#生成图片
save_image(gen_imgs.data[:25],"images.png",nrow=5,normalize=True)

你可能感兴趣的:(pytorch,GAN)