原理省略
需要两个文件
gan.py
module_load.py
gan.py–训练代码
##argparse是python用于解析命令行参数和选项的标准模块
#使用步骤:
#1 import argparse
#2 parser=argparse.ArgumentParser()
#3 parser.add_argument()
#4 parser.parse_args()
import argparse
import os
import numpy as np
import math
#用于data augmentation
import torchvision.transforms as transforms
#保存生成图像
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
##############################################################################
#设置模型中的某些参数,这里使用了argparse来集中操作
#如果根目录下不存在images文件夹,则创建images存放生成图像结果
os.makedirs('images',exist_ok=True)
#创建解析对象
parser=argparse.ArgumentParser()
#向解析对象中添加命令行参数和选项
#epoch = 20,屁大小=64,学习率=0.0002,衰减率=0.5/0.999,线程数=8,隐码维数=100,样本尺寸=28*28,通道数=1,样本间隔=400
parser.add_argument("--n_epochs",type=int,default=1,help="number of epochs of training")
parser.add_argument("--batch_size",type=int,default=64,help="size of batches")
parser.add_argument("--lr",type=float,default=0.0002,help="adam:learning rate")
parser.add_argument("--b1",type=float,default=0.5,help="adam:decay of first order momentum of gradient")
parser.add_argument("--b2",type=float,default=0.999,help="adam:decay of first order momentum of gradient")
parser.add_argument("--n_cpu",type=int,default=8,help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim",type=int,default=100,help="dimensionality of the latent space")
parser.add_argument("--img_size",type=int,default=28,help="size of each image dimension")
parser.add_argument("--channels",type=int,default=1,help="number of image channels")
parser.add_argument("--sample_interval",type=int,default=400,help="interval between image samples")
#解析参数
opt=parser.parse_args()
print(opt)
img_shape=(opt.channels,opt.img_size,opt.img_size)# 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda=True if torch.cuda.is_available() else False
#################################################################################
#创建生成器G
#---------------------
# 生成器
#---------------------
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
def block(in_feat,out_feat,normalize=True):
#这里简单的只对输入数据做线性转换
layers=[nn.Linear(in_feat,out_feat)]
#使用BN
if normalize:
layers.append(nn.BatchNorm1d(out_feat,0.8))
#添加LeakyReLUE非线性激活层
layers.append(nn.LeakyReLU(0.2,inplace=True))
return layers
#创建生成器网络模型
self.model = nn.Sequential(
*block(opt.latent_dim,128,normalize=False),
*block(128,256),
*block(256,512),
*block(512,1024),
nn.Linear(1024,int(np.prod(img_shape))),
nn.Tanh()
)
#前进
def forward(self,z):
#生成假样本
img=self.model(z)
img=img.view(img.size(0),*img_shape)
#返回生成图像
return img
###################################################################################
#创建判别器D
#-------------------
# 判别器
#-------------------
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.model=nn.Sequential(
nn.Linear(int(np.prod(img_shape)),512),
nn.LeakyReLU(0.2,inplace=True),
nn.Linear(512,256),
nn.Linear(256,1),
#因需判别真假,这里使用Sigmoid函数给出标量的判别结果
nn.Sigmoid(),
)
#判别
def forward(self,img):
img_flat=img.view(img.size(0),-1)
validity=self.model(img_flat)
#判别结果
return validity
###################################################################################
#损失函数和优化器
adversarial_loss=torch.nn.BCELoss()
#初始化生成器和鉴别器
generator=Generator()
discriminator=Discriminator()
if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
#优化器,G和D都使用adam
optimizer_G=torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
###################################################################################
#加载数据器
os.makedirs("../../data/mnist",exist_ok=True)
#--------------------------------------
# torch.utils.data.DataLoader
#--------------------------------------
#数据加载器,结合了数据集和取样器,并且可以提供多个线性处理数据集。
#在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。
#直至把所有的数据都抛出。就是做一个数据的初始化
#
#torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
# batch_sampler=None, num_workers=0, collate_fn=,
# pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
# dataset:加载数据的数据集
#batch_size:每批次加载的数据量
#shuffle:默认false,若为True,表示在每个epoch打乱数据
#sampler:定义从数据集中绘制示例的策略,如果指定,shuffle必须为false
#...
#更多可参考: https://pytorch.org/docs/stable/data.html
#设置数据加载器,这里使用MNIST数据集
dataloader=torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
if __name__ == '__main__':
##############################################################################
#训练模型
for epoch in range(opt.n_epochs):
for i, (imgs,_) in enumerate(dataloader):
valid=Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad=False)
fake=Variable(Tensor(imgs.size(0),1).fill_(0.0),requires_grad=False)
#输入
real_imgs=Variable(imgs.type(Tensor))
#------------------------
#训练G
#------------------------
optimizer_G.zero_grad()
#采样随机噪声问题
z=Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))
#训练得到一批次生成样本
gen_imgs=generator(z)
#计算G的损失函数值
g_loss=adversarial_loss(discriminator(gen_imgs),valid)
#更新G
g_loss.backward()
optimizer_G.step()
#---------------------
# 训练D
#---------------------
optimizer_D.zero_grad()
#评估D的判别能力
real_loss=adversarial_loss(discriminator(real_imgs),valid)
fake_loss=adversarial_loss(discriminator(gen_imgs.detach()),fake)
d_loss=(real_loss+fake_loss)/2
#更新D
d_loss.backward()
optimizer_D.step()
if (i+1)%10==0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss:%f] [G loss: %f]"
%(epoch, opt.n_epochs,i, len(dataloader),d_loss.item(),g_loss.item())
)
#保存结果
batches_done=epoch*len(dataloader)+i
if batches_done % opt.sample_interval ==0:
save_image(gen_imgs.data[:25],"images/%d.png" %batches_done,nrow=5,normalize=True)
if (epoch+1)%1==0 and i+1==len(dataloader):
print('save__')
torch.save(generator,'g%d.pth'%epoch)
torch.save(discriminator,'d%d.pth'%epoch)
module_load.py—导入最后保存生成器的模型
from gan import Generator, Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
Tensor=torch.FloatTensor
g=torch.load('g0.pth')#导入生成器Generator模型
#d=torch.load('d.pth')
g=g.to(device)
#d=d.to(device)
z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #输入的噪音
gen_imgs=g(z)#生成图片
save_image(gen_imgs.data[:25],"images.png",nrow=5,normalize=True)