涉及的论文
GAN
https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
DCGAN
https://arxiv.org/pdf/1511.06434.pdf
测试用的数据集
Celeb-A Faces
数据集网站:
http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
下载链接:
百度网盘: https://pan.baidu.com/s/1eSNpdRG#list/path=%2F
谷歌网盘: https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
数据集下载后,找到名为 img_align_celeba.zip 的文件。
创建一个文件夹data,然后在data内创建一个文件夹celeba.
将img_align_celeba.zip 拷贝进celeba,然后解压
unzip img_align_celeba.zip
会生成这样的目录结构
./data/celeba/
    -> img_align_celeba/
        -> 188242.jpg
        -> 173822.jpg
        -> 284792.jpg
        ...
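这一步很重要,因为代码使用 torchvision 的 ImageFolder 以 ./data/celeba/ 为根目录读取数据:ImageFolder 会把根目录下的每个子文件夹当作一个类别,所以图片必须放在子文件夹 img_align_celeba 中,而不能直接放在 celeba 目录下。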
实现DCGAN 包含的文件
main.py
etc.py
graph.py
model.py
show.py
record.py
DCGAN_architecture.py
celeba_dataset.py
main.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-01-25 14:07
# Modified date : 2019-01-27 22:36
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import celeba_dataset
from etc import config
from graph import NNGraph
def run():
    """Build the CelebA dataloader and train the DCGAN using etc.config."""
    dataloader = celeba_dataset.get_dataloader(config)
    g = NNGraph(dataloader, config)
    g.train()


# Guard the entry point so that importing this module does not start training.
# This is also required when the DataLoader uses worker processes
# (config["workers"] > 0) on spawn-based platforms such as Windows: each
# worker re-imports the main module, and an unguarded run() would start a
# second training run inside every worker.
if __name__ == "__main__":
    run()
etc.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-24 17:02
# Modified date : 2019-01-28 23:37
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import torch
# Global hyper-parameter / runtime settings shared by every module.
config = {}
config["dataset"] = "celeba"
config["batch_size"] = 128
config["image_size"] = 64
config["num_epochs"] = 5
# Root folder handed to ImageFolder; images must sit one sub-folder below it.
config["data_path"] = "data/%s" % config["dataset"]
config["workers"] = 2
config["print_every"] = 200
config["save_every"] = 500
config["manual_seed"] = 999
config["train_load_check_point_file"] = False
# To get new results on every run, replace the fixed seed above with a random
# one (this also requires `import random` at the top of this file):
# config["manual_seed"] = random.randint(1, 10000)
config["number_channels"] = 3
config["size_of_z_latent"] = 100
config["number_gpus"] = 1
config["number_of_generator_feature"] = 64
config["number_of_discriminator_feature"] = 64
config["learn_rate"] = 0.0002
config["beta1"] = 0.5
# Labels are floats because they are used as nn.BCELoss targets, which must be
# floating point; an int fill value makes torch.full build an int64 tensor.
config["real_label"] = 1.0
config["fake_label"] = 0.0
config["device"] = torch.device("cuda:0" if (torch.cuda.is_available() and config["number_gpus"] > 0) else "cpu")
graph.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-24 17:17
# Modified date : 2019-01-28 17:46
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import os
import time
import torch
import torchvision.utils as vutils
import model
import show
import record
class NNGraph(object):
    """Owns the DCGAN training state and drives the training loop.

    All networks, optimizers, loss histories and counters live in the
    ``self.train_model`` dict built by ``model.init_train_model`` and
    optionally restored from a checkpoint file.
    """

    def __init__(self, dataloader, config):
        super(NNGraph, self).__init__()
        # self.config must be set before _get_train_model(): the checkpoint
        # path used by _load_train_model is derived from it.
        self.config = config
        self.train_model = self._get_train_model(config)
        record.record_dict(self.config, self.train_model["config"])
        # If a checkpoint was loaded, its stored config takes over from here.
        self.config = self.train_model["config"]
        self.dataloader = dataloader

    def _get_train_model(self, config):
        """Build fresh training state, then overlay a checkpoint if enabled."""
        train_model = model.init_train_model(config)
        train_model = self._load_train_model(train_model)
        return train_model

    def _save_train_model(self):
        """Serialize the current training state to the checkpoint file."""
        model_dict = model.get_model_dict(self.train_model)
        file_full_path = record.get_check_point_file_full_path(self.config)
        torch.save(model_dict, file_full_path)

    def _load_train_model(self, train_model):
        """Restore ``train_model`` from the checkpoint file when allowed.

        Loading happens only if the file exists and
        ``config["train_load_check_point_file"]`` is truthy. Otherwise
        ``train_model`` is returned unchanged.
        """
        file_full_path = record.get_check_point_file_full_path(self.config)
        if os.path.exists(file_full_path) and self.config["train_load_check_point_file"]:
            device = self.config["device"]
            # map_location moves tensors saved on another device (e.g. a CUDA
            # checkpoint resumed on a CPU-only machine) onto the current one.
            checkpoint = torch.load(file_full_path, map_location=device)
            train_model = model.load_model_dict(train_model, checkpoint)
            # The networks were just built on the current device, so keep the
            # runtime device instead of the one stored with the checkpoint.
            train_model["config"]["device"] = device
        return train_model

    def _train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Returns:
            (errD, errG, D_x, D_G_z1, D_G_z2) as produced by the model helpers.
        """
        netG = self.train_model["netG"]
        optimizerG = self.train_model["optimizerG"]
        netD = self.train_model["netD"]
        optimizerD = self.train_model["optimizerD"]
        criterion = self.train_model["criterion"]
        device = self.config["device"]
        real_data = data[0].to(device)
        noise = model.get_noise(real_data, self.config)
        fake_data = netG(noise)
        label = model.get_label(real_data, self.config)
        # The D update must not backpropagate into G, hence detach().
        errD, D_x, D_G_z1 = model.get_Discriminator_loss(netD, optimizerD, real_data, fake_data.detach(), label, criterion, self.config)
        errG, D_G_z2 = model.get_Generator_loss(netG, netD, optimizerG, fake_data, label, criterion, self.config)
        return errD, errG, D_x, D_G_z1, D_G_z2

    def _train_a_step(self, data, i, epoch):
        """Train on one batch, accumulate wall-clock time and log progress.

        Progress is printed every ``print_every`` batches. The first value
        passed to ``print_status`` is this step's time multiplied by
        ``print_every``, i.e. an extrapolated estimate.
        """
        start = time.time()
        errD, errG, D_x, D_G_z1, D_G_z2 = self._train_step(data)
        end = time.time()
        step_time = end - start
        self.train_model["take_time"] = self.train_model["take_time"] + step_time
        print_every = self.config["print_every"]
        if i % print_every == 0:
            record.print_status(step_time*print_every,
                                self.train_model["take_time"],
                                epoch,
                                i,
                                errD,
                                errG,
                                D_x,
                                D_G_z1,
                                D_G_z2,
                                self.config,
                                self.dataloader)
        return errD, errG

    def _DCGAN_eval(self):
        """Return netG's output for the stored fixed noise, on the CPU."""
        fixed_noise = self.train_model["fixed_noise"]
        with torch.no_grad():
            netG = self.train_model["netG"]
            fake = netG(fixed_noise).detach().cpu()
        return fake

    def _save_generator_images(self, iters, epoch, i):
        """Snapshot generator output and checkpoint at save points.

        Save points are every ``save_every`` iterations, plus the last batch
        of the final epoch. At each one, a normalized image grid of the
        fixed-noise samples is appended to ``img_list`` and a checkpoint is
        written.
        """
        num_epochs = self.config["num_epochs"]
        save_every = self.config["save_every"]
        img_list = self.train_model["img_list"]
        if (iters % save_every == 0) or ((epoch == num_epochs-1) and (i == len(self.dataloader)-1)):
            fake = self._DCGAN_eval()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            self._save_train_model()

    def _train_iters(self):
        """Train from the stored counters up to ``num_epochs``.

        Per-iteration losses are appended to the G/D histories. A resumed
        run restarts the stored epoch from its first batch.
        """
        num_epochs = self.config["num_epochs"]
        G_losses = self.train_model["G_losses"]
        D_losses = self.train_model["D_losses"]
        iters = self.train_model["current_iters"]
        start_epoch = self.train_model["current_epoch"]
        for epoch in range(start_epoch, num_epochs):
            self.train_model["current_epoch"] = epoch
            for i, data in enumerate(self.dataloader, 0):
                errD, errG = self._train_a_step(data, i, epoch)
                G_losses.append(errG.item())
                D_losses.append(errD.item())
                iters += 1
                self.train_model["current_iters"] = iters
                self._save_generator_images(iters, epoch, i)

    def train(self):
        """Run the full training loop, then save the summary figures."""
        self._train_iters()
        show.show_images(self.train_model, self.config, self.dataloader)
model.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : model.py
# Create date : 2019-01-24 17:00
# Modified date : 2019-01-29 00:43
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import random
import torch
import torch.nn as nn
import torch.optim as optim
from DCGAN_architecture import Generator, Discriminator
import record
def _weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
def _random_init(config):
manualSeed = config["manual_seed"]
random.seed(manualSeed)
torch.manual_seed(manualSeed)
def _get_a_net(Net, config):
    """Instantiate ``Net(config)`` on the configured device, initialise and log it.

    On CUDA with more than one GPU configured, the network is wrapped in
    ``nn.DataParallel`` over GPUs ``0 .. number_gpus-1``. DCGAN weight init
    is then applied, and the network's repr is recorded via
    ``record.save_status``.
    """
    n_gpus = config["number_gpus"]
    device = config["device"]
    net = Net(config).to(device)
    if device.type == 'cuda' and n_gpus > 1:
        net = nn.DataParallel(net, device_ids=list(range(n_gpus)))
    net.apply(_weights_init)
    record.save_status(config, net)
    return net
def _get_optimizer(net, config):
lr = config["learn_rate"]
beta1 = config["beta1"]
opt = optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
return opt
def _get_fixed_noise(config):
nz = config["size_of_z_latent"]
device = config["device"]
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
return fixed_noise
def load_model_dict(train_model, checkpoint):
    """Restore ``train_model`` from a dict produced by ``get_model_dict``.

    Networks, loss module and optimizers are updated in place through
    ``load_state_dict``. All other entries are replaced by the checkpoint's
    values. Returns the same ``train_model`` dict.
    """
    for key in ("netG", "netD", "criterion", "optimizerD", "optimizerG"):
        train_model[key].load_state_dict(checkpoint[key])
    for key in ("fixed_noise", "G_losses", "D_losses", "img_list",
                "current_iters", "current_epoch", "config", "take_time"):
        train_model[key] = checkpoint[key]
    return train_model
def get_model_dict(train_model):
    """Build a checkpoint dict from ``train_model`` (the inverse of ``load_model_dict``).

    Modules and optimizers are stored as their ``state_dict()``. All other
    entries are stored by reference, as-is.
    """
    stateful = ("netG", "netD", "criterion", "optimizerD", "optimizerG")
    plain = ("fixed_noise", "G_losses", "D_losses", "img_list",
             "current_iters", "current_epoch", "config", "take_time")
    model_dict = {key: train_model[key].state_dict() for key in stateful}
    model_dict.update((key, train_model[key]) for key in plain)
    return model_dict
def init_train_model(config):
    """Create a fresh training-state dict for DCGAN.

    Seeds the RNGs first, then builds G and D, a BCE loss, one Adam
    optimizer per network and a fixed noise batch, in that order. History
    lists start empty and counters start at zero.
    """
    _random_init(config)
    netG = _get_a_net(Generator, config)
    netD = _get_a_net(Discriminator, config)
    # Dict values are evaluated top to bottom, keeping the original
    # construction order (and hence RNG consumption) intact.
    return {
        "netG": netG,
        "netD": netD,
        "criterion": nn.BCELoss(),
        "optimizerD": _get_optimizer(netD, config),
        "optimizerG": _get_optimizer(netG, config),
        "fixed_noise": _get_fixed_noise(config),
        "G_losses": [],
        "D_losses": [],
        "img_list": [],
        "current_iters": 0,
        "current_epoch": 0,
        "config": config,
        "take_time": 0.0,
    }
def _run_Discriminator(netD, data, label, loss):
output = netD(data).view(-1)
err = loss(output, label)
err.backward()
m = output.mean().item()
return err, m
def get_Discriminator_loss(netD, optimizerD, real_data, fake_data, label, criterion, config):
    """Perform one discriminator update on a real batch and a fake batch.

    Gradients from both passes accumulate in netD before a single
    ``optimizerD.step()``. ``label`` is overwritten in place, so on return
    it holds ``config["fake_label"]``. The real pass uses ``label`` as
    given; callers in this project pass a tensor freshly filled with
    real_label by ``get_label``.

    Returns:
        (errD, D_x, D_G_z1): the summed criterion loss tensor, and netD's
        mean output on the real and fake batches (Python floats).
    """
    netD.zero_grad()
    loss_real, mean_real = _run_Discriminator(netD, real_data, label, criterion)
    label.fill_(config["fake_label"])
    loss_fake, mean_fake = _run_Discriminator(netD, fake_data, label, criterion)
    total = loss_real + loss_fake
    optimizerD.step()
    return total, mean_real, mean_fake
def get_Generator_loss(netG, netD, optimizerG, fake_data, label, criterion, config):
    """Perform one generator update.

    ``label`` is refilled in place with ``config["real_label"]``: the
    generator is trained to make netD score its samples as real. Gradients
    also flow into netD's parameters, but only ``optimizerG`` is stepped here.

    Returns:
        (errG, D_G_z2): the loss tensor and netD's mean output on ``fake_data``.
    """
    netG.zero_grad()
    target = config["real_label"]
    label.fill_(target)
    result = _run_Discriminator(netD, fake_data, label, criterion)
    optimizerG.step()
    return result
def get_label(data, config):
    """Return a label vector filled with ``config["real_label"]``.

    The vector has one entry per sample in ``data`` (``data.size(0)``) and
    lives on ``config["device"]``.

    The dtype is pinned to float32 because the labels are used as
    ``nn.BCELoss`` targets, which must be floating point. Without an explicit
    dtype, ``torch.full`` infers an integer dtype from an int fill value
    (such as ``real_label = 1``), and the loss call then fails.
    """
    b_size = data.size(0)
    return torch.full((b_size,), config["real_label"], dtype=torch.float, device=config["device"])
def get_noise(data, config):
    """Sample a standard-normal latent batch of shape (N, size_of_z_latent, 1, 1).

    N is ``data.size(0)``, so there is one latent vector per sample in
    ``data``. The batch is created on ``config["device"]``.
    """
    shape = (data.size(0), config["size_of_z_latent"], 1, 1)
    return torch.randn(*shape, device=config["device"])
show.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : show.py
# Create date : 2019-01-24 17:19
# Modified date : 2019-01-28 17:31
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import numpy as np
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from matplotlib import rcParams
from matplotlib.animation import ImageMagickWriter
import record
# Raise matplotlib's size cap for animations embedded as HTML/JS
# (e.g. via ani.to_jshtml()); matplotlib measures this limit in MB.
rcParams.update({"animation.embed_limit": 500})
def show_some_batch(real_batch, device):
    """Display up to 64 images from ``real_batch[0]`` as a single grid window."""
    images = real_batch[0].to(device)[:64]
    grid = vutils.make_grid(images, padding=2, normalize=True).cpu()
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training Images")
    # make_grid yields (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.show()
def _plot_real_and_fake_images(real_batch, device, img_list, save_path):
    """Save a side-by-side figure of real images and the latest generated grid.

    Left panel: up to 64 images from ``real_batch[0]``. Right panel:
    ``img_list[-1]``, omitted when ``img_list`` is empty (no snapshot was
    ever taken). The figure is written to ``<save_path>/real_and_fake.jpg``
    and then closed.
    """
    fig = plt.figure(figsize=(30, 30))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    # Slice before the device transfer so only the 64 displayed images move.
    real_images = real_batch[0][:64].to(device)
    real_grid = vutils.make_grid(real_images, padding=5, normalize=True).cpu()
    # make_grid yields (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(real_grid, (1, 2, 0)))
    if img_list:
        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    name = "real_and_fake.jpg"
    full_path_name = "%s/%s" % (save_path, name)
    plt.savefig(full_path_name)
    # Nothing displays the figure afterwards; close it so open figures
    # do not accumulate for the life of the process.
    plt.close(fig)
def _show_generator_images(G_losses, D_losses, save_path):
    """Plot the G and D loss histories and save them as ``<save_path>/G_D_losses.jpg``.

    Despite its name, this function draws loss curves, not generated images.
    The figure is closed after saving.
    """
    fig = plt.figure(figsize=(40, 20))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    name = "G_D_losses.jpg"
    full_path_name = "%s/%s" % (save_path, name)
    plt.savefig(full_path_name)
    # Nothing displays the figure afterwards; release it.
    plt.close(fig)
def _show_img_list(img_list):
    """Animate the generator snapshots in ``img_list`` at one frame per second.

    Intended for notebook use. The animation is rendered as embedded
    HTML/JS, and then ``plt.show()`` is called.
    """
    # Imported here because only this notebook helper needs display().
    from IPython.display import display
    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(img, (1, 2, 0)), animated=True)] for img in img_list]
    ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
    # A bare HTML(...) expression inside a function is discarded, so
    # display() is required for the animation to actually render.
    display(HTML(ani.to_jshtml()))
    plt.show()
def _save_img_list(img_list, save_path, config):
    """Write the generator snapshots in ``img_list`` to ``<save_path>/img_list.gif``.

    The GIF runs at 1 fps and is rendered at dpi 500 with ImageMagick.
    Frames are titled with the training iteration they were taken at.
    During training a snapshot is appended whenever the (1-based) iteration
    counter is a multiple of ``config["save_every"]``. Assuming a single
    uninterrupted run, frame k therefore shows step ``(k + 1) * save_every``.
    The final frame may instead be the end-of-training snapshot, which need
    not fall on that boundary.
    """
    metadata = dict(title='generator images', artist='Matplotlib', comment='Movie support!')
    writer = ImageMagickWriter(fps=1, metadata=metadata)
    save_every = config["save_every"]
    fig, ax = plt.subplots()
    with writer.saving(fig, "%s/img_list.gif" % save_path, 500):
        for k, img in enumerate(img_list):
            # Clear so frames do not stack image artists on top of each other.
            ax.clear()
            # Snapshots are (C, H, W); imshow expects (H, W, C).
            ax.imshow(np.transpose(img, (1, 2, 0)))
            ax.set_title("step {}".format((k + 1) * save_every))
            writer.grab_frame()
    plt.close(fig)
def show_images(train_model, config, dataloader):
    """Save all end-of-training figures into the checkpoint directory.

    Writes the loss curves, the GIF of generator snapshots, and a
    real-vs-fake comparison that uses the first batch drawn from
    ``dataloader``.
    """
    img_list = train_model["img_list"]
    save_path = record.get_check_point_path(config)
    _show_generator_images(train_model["G_losses"], train_model["D_losses"], save_path)
    _save_img_list(img_list, save_path, config)
    _plot_real_and_fake_images(next(iter(dataloader)), config["device"], img_list, save_path)
record.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : record.py
# Create date : 2019-01-28 15:51
# Modified date : 2019-01-28 18:07
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import os
def _get_param_str(config):
# pylint: disable=bad-continuation
param_str = "%s_%s_%s_%s_%s_%s_%s" % (
config["dataset"],
config["image_size"],
config["batch_size"],
config["number_of_generator_feature"],
config["number_of_discriminator_feature"],
config["size_of_z_latent"],
config["learn_rate"],
)
# pylint: enable=bad-continuation
return param_str
def get_check_point_path(config):
    """Return ``<data_path>/save/<param_str>/`` and create it if it does not exist.

    The returned path always ends with "/".
    """
    directory = "%s/save/%s/" % (config["data_path"], _get_param_str(config))
    if not os.path.exists(directory):
        os.makedirs(directory)
    return directory
def get_check_point_file_full_path(config):
    """Return the checkpoint file path: ``<checkpoint dir><param_str>checkpoint.tar``.

    Getting the directory creates it if it does not exist yet.
    """
    directory = get_check_point_path(config)
    return directory + _get_param_str(config) + "checkpoint.tar"
def _write_output(config, con):
    """Append ``con`` plus a newline to the ``output`` log file in the checkpoint directory."""
    save_path = get_check_point_path(config)
    file_full_path = "%s/output" % save_path
    # The context manager closes the file even if write() raises.
    with open(file_full_path, "a") as f:
        f.write("%s\n" % con)
def record_dict(config, dic):
    """Log a "config:" header, then one "key : value" line per entry of ``dic``."""
    save_status(config, "config:")
    for key, value in dic.items():
        save_status(config, "%s : %s" % (key, value))
def save_status(config, con):
    """Echo ``con`` to stdout and append it to the run's output log file."""
    # Console first, then the persistent log in the checkpoint directory.
    print(con)
    _write_output(config, con)
def print_status(step_time, take_time, epoch, i, errD, errG, D_x, D_G_z1, D_G_z2, config, dataloader):
    """Format one training-progress line and log it through ``save_status``.

    The line shows epoch/num_epochs, batch/len(dataloader), both losses,
    D(x), D(G(z)) before and after the G update, and the accumulated time
    in whole seconds. ``step_time`` is accepted but not printed.
    """
    fmt = '[%d/%d]\t[%d/%d]\t Loss_D: %.4f\t Loss_G: %.4f\t D(x): %.4f\t D(G(z)): %.4f / %.4f take_time: %.fs'
    values = (
        epoch, config["num_epochs"],
        i, len(dataloader),
        errD.item(), errG.item(),
        D_x, D_G_z1, D_G_z2,
        take_time,
    )
    save_status(config, fmt % values)
DCGAN_architecture.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : DCGAN_architecture.py
# Create date : 2019-01-26 23:16
# Modified date : 2019-01-27 22:47
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import torch.nn as nn
class Generator(nn.Module):
    """DCGAN generator: transposed-convolution stack from latent vector to image.

    For a (N, size_of_z_latent, 1, 1) input, the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64. The output has ``number_channels``
    channels squashed into [-1, 1] by Tanh. Layer order (and therefore
    state_dict keys ``main.<idx>.*``) follows the canonical DCGAN layout.
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        self.ngpu = config["number_gpus"]
        z_dim = config["size_of_z_latent"]
        base = config["number_of_generator_feature"]
        channels = config["number_channels"]
        # Project the 1x1 latent to a 4x4 map with base*8 feature maps.
        layers = [
            nn.ConvTranspose2d(z_dim, base * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base * 8),
            nn.ReLU(True),
        ]
        # Three upsampling stages, each doubling resolution and halving width.
        for mult in (8, 4, 2):
            width_out = base * mult // 2
            layers += [
                nn.ConvTranspose2d(base * mult, width_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.ReLU(True),
            ]
        # Final upsample to image channels.
        layers += [nn.ConvTranspose2d(base, channels, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
class Discriminator(nn.Module):
    """DCGAN discriminator: strided-convolution stack from image to probability.

    For a (N, number_channels, 64, 64) input, the spatial size shrinks
    64 -> 32 -> 16 -> 8 -> 4 -> 1. The output is (N, 1, 1, 1), passed
    through Sigmoid. Layer order (and therefore state_dict keys
    ``main.<idx>.*``) follows the canonical DCGAN layout.
    """

    def __init__(self, config):
        super(Discriminator, self).__init__()
        self.ngpu = config["number_gpus"]
        base = config["number_of_discriminator_feature"]
        channels = config["number_channels"]
        # First stage has no batch norm.
        layers = [
            nn.Conv2d(channels, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, each halving resolution and doubling width.
        for mult in (1, 2, 4):
            width_out = base * mult * 2
            layers += [
                nn.Conv2d(base * mult, width_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the 4x4 map to a single score.
        layers += [nn.Conv2d(base * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
celeba_dataset.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : celeba_dataset.py
# Create date : 2019-01-24 18:02
# Modified date : 2019-01-26 22:57
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
def get_dataloader(config):
    """Build a shuffled DataLoader over the images under ``config["data_path"]``.

    Images are resized and center-cropped to ``image_size``, then converted
    to tensors and normalized from [0, 1] to [-1, 1], the same range as the
    generator's Tanh output. The dataset is ``ImageFolder``, which treats
    each sub-directory of the root as a class. The images must therefore
    sit one level down, e.g. ``data/celeba/img_align_celeba/*.jpg``.
    """
    image_size = config["image_size"]
    batch_size = config["batch_size"]
    dataroot = config["data_path"]
    workers = config["workers"]
    # A plain assignment: the original chained `tf = transform = ...`
    # created an unused local named `transform`.
    tf = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dset.ImageFolder(root=dataroot, transform=tf)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers)
    return dataloader
运行2个epoch后生成的结果
Loss_G 和 Loss_D 对比图
真实图片和假图片对比
运行5个epoch后的结果
Loss_G 和 Loss_D 的对比图
真实图片和假图片对比
运行200个epoch后的结果
真实图片和假图片对比
GitHub: https://github.com/darr/DCGAN