使用DCGAN训练faces数据集,最终实现生成二次元动漫头像。
最后虽然生成了动漫头像,但是一些细节还是和真实的图像差别较大,比如说眼睛大小,眼睛颜色等。
之后我会将MNIST数据集、Oxford17数据集以及faces数据集在训练过程中不同轮次的输出结果做一个总结。
生成二次元动漫头像的程序依然沿用data.py、model.py、net.py、main.py这四个文件,但具体的实现细节有所改变。
之前MNIST以及Oxford17数据集的程序
这里:
【Pytorch】DCGAN实战(一):基于MNIST数据集的手写数字生成
【Pytorch】DCGAN实战(二):基于Oxford17的鲜花图像生成
Python版本为3.7
在这里不详细介绍了,网上有很多的安装教程,小伙伴们自行查找吧!
Pycharm
整体分为4个文件:data.py、model.py、net.py、main.py
from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms
class ReadData():
    """Wrap an image-folder dataset and hand out batched loaders over it.

    Args:
        data_path: Root directory given to ``datasets.ImageFolder``.
        image_size: Size used for both the resize and the center crop.
            Defaults to 64.
    """

    def __init__(self, data_path, image_size=64):
        self.root = data_path
        self.image_size = image_size
        # The dataset is built eagerly so its size is printed at construction.
        self.dataset = self.getdataset()

    def getdataset(self):
        """Build the ``ImageFolder`` dataset, print its length and return it."""
        channel_mean = (0.5, 0.5, 0.5)
        channel_std = (0.5, 0.5, 0.5)
        # Resize -> center crop -> tensor -> per-channel (x - 0.5) / 0.5.
        pipeline = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        image_folder = datasets.ImageFolder(root=self.root, transform=pipeline)
        print(f'Total Size of Dataset: {len(image_folder)}')
        return image_folder

    def getdataloader(self, batch_size=128):
        """Return a shuffling ``DataLoader`` (``num_workers=0``) over the dataset.

        Args:
            batch_size: Number of samples per batch. Defaults to 128.
        """
        loader_options = {
            'batch_size': batch_size,
            'shuffle': True,
            'num_workers': 0,
        }
        return DataLoader(self.dataset, **loader_options)
import torch.nn as nn
class Generator(nn.Module):
    """DCGAN generator: five transposed convolutions ending in ``Tanh``.

    The first layer uses kernel 4 / stride 1 / padding 0 and every later
    layer uses kernel 4 / stride 2 / padding 1, so a ``nz x 1 x 1`` input
    grows 4 -> 8 -> 16 -> 32 -> 64 pixels per side.

    Args:
        nz: Channel count of the latent input.
        ngf: Base feature-map width; hidden stages use 8x, 4x, 2x and 1x of it.
        nc: Channel count of the generated image.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()
        self.nz = nz
        self.ngf = ngf
        self.nc = nc

        # Output channels of the four hidden stages, widest first.
        widths = [self.ngf * 8, self.ngf * 4, self.ngf * 2, self.ngf]

        # Stage 1 takes the latent input: (nz) x 1 x 1 -> (ngf*8) x 4 x 4.
        layers = [
            nn.ConvTranspose2d(self.nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Stages 2-4 each double the spatial size and halve the channels.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(widths[-1], self.nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        generated = self.main(input)
        return generated
class Discriminator(nn.Module):
    """DCGAN discriminator: five strided convolutions ending in ``Sigmoid``.

    The first four convolutions use kernel 4 / stride 2 / padding 1 and the
    last uses kernel 4 / stride 1 / padding 0, so a ``nc x 64 x 64`` input
    shrinks 32 -> 16 -> 8 -> 4 -> 1 pixels per side.

    Args:
        ndf: Base feature-map width; hidden stages use 1x, 2x, 4x and 8x of it.
        nc: Channel count of the input image.
    """

    def __init__(self, ndf, nc):
        super().__init__()
        self.ndf = ndf
        self.nc = nc

        # Output channels of the four down-sampling stages, narrowest first.
        widths = [self.ndf, self.ndf * 2, self.ndf * 4, self.ndf * 8]

        # Stage 1 has no batch norm: (nc) x 64 x 64 -> (ndf) x 32 x 32.
        layers = [
            nn.Conv2d(self.nc, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Stages 2-4 each halve the spatial size and double the channels.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Final stage: (ndf*8) x 4 x 4 -> (1) x 1 x 1, squashed by Sigmoid.
        layers += [
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        score = self.main(input)
        return score
def weights_init(m):
    """Re-initialise ``m`` in place, chosen by its class name.

    Classes whose name contains ``Conv`` get weights drawn from
    N(0.0, 0.02). Classes whose name contains ``BatchNorm`` get weights
    drawn from N(1.0, 0.02) and a zero bias. Any other module is left
    untouched. Intended for use with ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
import torch
import torch.nn as nn
from torchvision import utils, datasets, transforms
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os
class DCGAN():
    """Training and visualisation harness for a generator/discriminator pair.

    Args:
        lr: Learning rate used for both Adam optimisers.
        beta1: First Adam beta for both optimisers (the second is 0.999).
        nz: Channel count of the latent noise fed to the generator.
        batch_size: Stored on the instance; not read by this class.
        num_showimage: Number of fixed noise vectors (and real images) shown
            in the preview grids.
        device: Device the networks, loss and noise tensors live on.
        model_save_path: Directory that receives ``disc_{epoch}.pth`` and
            ``gen_{epoch}.pth``. A trailing separator is optional.
        figure_save_path: Directory that receives the figures and the GIF.
            A trailing separator is optional.
        generator: Generator network; moved to ``device``.
        discriminator: Discriminator network; moved to ``device``.
        data_loader: Iterable of batches whose element ``[0]`` is the image
            tensor. It must support ``len()``.
    """

    def __init__(self,lr,beta1,nz, batch_size,num_showimage,device, model_save_path,figure_save_path,generator, discriminator, data_loader,):
        self.real_label = 1
        self.fake_label = 0
        self.nz = nz
        self.batch_size = batch_size
        self.num_showimage = num_showimage
        self.device = device
        self.model_save_path = model_save_path
        self.figure_save_path = figure_save_path

        self.G = generator.to(device)
        self.D = discriminator.to(device)
        self.opt_G = torch.optim.Adam(self.G.parameters(), lr=lr, betas=(beta1, 0.999))
        self.opt_D = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(beta1, 0.999))
        self.criterion = nn.BCELoss().to(device)

        self.dataloader = data_loader
        # The same noise is reused for every preview so epochs are comparable.
        self.fixed_noise = torch.randn(self.num_showimage, nz, 1, 1, device=device)

        # Per-epoch preview grids and per-step statistics, filled by train().
        self.img_list = []
        self.G_loss_list = []
        self.D_loss_list = []
        self.D_x_list = []
        self.D_z_list = []

    def train(self, num_epochs):
        """Train both networks for ``num_epochs`` epochs, then save the plots.

        After every epoch the two state dicts are written to
        ``model_save_path`` and a preview grid generated from the fixed noise
        is appended to ``img_list``. When training ends, the loss curves, the
        D(x)/D(G(z)) curves, an animated GIF of the previews and a real-vs-fake
        comparison are saved to ``figure_save_path``. The generator is left in
        eval mode on return.
        """
        # A previous train()/test() call leaves G in eval mode; switch back so
        # BatchNorm uses batch statistics while training.
        self.G.train()
        self.D.train()

        # Create the output directories once, before any work is done.
        os.makedirs(self.model_save_path, exist_ok=True)
        os.makedirs(self.figure_save_path, exist_ok=True)

        num_steps = len(self.dataloader)
        epoch_width = len(str(num_epochs))
        step_width = len(str(num_steps))

        print("Starting Training Loop...")
        for epoch in range(num_epochs):
            beg_time = time.time()
            for i, data in enumerate(self.dataloader):
                loss_d, loss_g, d_x, d_gz1, d_gz2 = self._train_step(data[0].to(self.device))

                # Seconds elapsed since the start of the current epoch.
                run_time = round(time.time() - beg_time)
                print(
                    f'Epoch: [{epoch + 1:0>{epoch_width}}/{num_epochs}]',
                    f'Step: [{i + 1:0>{step_width}}/{num_steps}]',
                    f'Loss-D: {loss_d:.4f}',
                    f'Loss-G: {loss_g:.4f}',
                    f'D(x): {d_x:.4f}',
                    f'D(G(z)): [{d_gz1:.4f}/{d_gz2:.4f}]',
                    f'Time: {run_time}s',
                    end='\r\n'
                )

                # Only Python floats are stored, so no autograd graph is kept alive.
                self.G_loss_list.append(loss_g)
                self.D_loss_list.append(loss_d)
                self.D_x_list.append(d_x)
                self.D_z_list.append(d_gz2)

            self._save_checkpoint(epoch)

            # Track the generator's progress on the fixed noise.
            with torch.no_grad():
                fake = self.G(self.fixed_noise).detach().cpu()
            self.img_list.append(utils.make_grid(fake * 0.5 + 0.5, nrow=10))
            print()

        self._plot_training_curves(num_epochs)
        self._save_animation(num_epochs)
        self._plot_real_vs_fake(self._figure_path(str(num_epochs) + 'epochs_' + 'result.jpg'), num=4)

    def test(self, epoch):
        """Load the generator saved after ``epoch`` and plot real vs. fake images.

        The comparison is saved as ``result.jpg`` in ``figure_save_path``.
        """
        os.makedirs(self.figure_save_path, exist_ok=True)
        self.load_generator(epoch)
        self._plot_real_vs_fake(self._figure_path('result.jpg'))

    def load_generator(self, epoch):
        """Load ``gen_{epoch}.pth`` from ``model_save_path`` into the generator.

        Returns:
            The generator module, with the loaded weights.
        """
        checkpoint = os.path.join(self.model_save_path, 'gen_{}.pth'.format(epoch))
        state = torch.load(checkpoint, map_location=torch.device(self.device))
        self.G.load_state_dict(state)
        return self.G

    def _train_step(self, x):
        """Run one discriminator update and one generator update on batch ``x``.

        Returns:
            ``(loss_d, loss_g, d_x, d_gz1, d_gz2)`` as Python floats, where
            ``d_x`` is the mean D output on real images and ``d_gz1`` /
            ``d_gz2`` are the mean D outputs on fakes before / after the
            discriminator update.
        """
        b_size = x.size(0)
        real_target = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)
        fake_target = torch.full((b_size,), self.fake_label, dtype=torch.float, device=self.device)

        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        d_real = self.D(x).view(-1)
        loss_d_real = self.criterion(d_real, real_target)

        z = torch.randn(b_size, self.nz, 1, 1, device=self.device)
        gz = self.G(z)
        # detach() keeps the discriminator update from back-propagating into G.
        d_fake = self.D(gz.detach()).view(-1)
        loss_d_fake = self.criterion(d_fake, fake_target)

        loss_d = loss_d_real + loss_d_fake
        self.opt_D.zero_grad()
        loss_d.backward()
        self.opt_D.step()

        # (2) Update G network: maximize log(D(G(z)))
        # Fakes are labelled as real for the generator's loss.
        d_fake_for_g = self.D(gz).view(-1)
        loss_g = self.criterion(d_fake_for_g, real_target)
        self.opt_G.zero_grad()
        loss_g.backward()
        self.opt_G.step()

        return (
            loss_d.item(),
            loss_g.item(),
            d_real.mean().item(),
            d_fake.mean().item(),
            d_fake_for_g.mean().item(),
        )

    def _save_checkpoint(self, epoch):
        """Write both networks' state dicts for ``epoch`` to ``model_save_path``."""
        torch.save(self.D.state_dict(), os.path.join(self.model_save_path, 'disc_{}.pth'.format(epoch)))
        torch.save(self.G.state_dict(), os.path.join(self.model_save_path, 'gen_{}.pth'.format(epoch)))

    def _figure_path(self, filename):
        """Return the path of ``filename`` inside ``figure_save_path``."""
        return os.path.join(self.figure_save_path, filename)

    def _plot_training_curves(self, num_epochs):
        """Save the loss curves and the D(x)/D(G(z)) curves (every 10th step)."""
        plt.figure(1, figsize=(8, 4))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(self.G_loss_list[::10], label="G")
        plt.plot(self.D_loss_list[::10], label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.axhline(y=0, label="0", c="g")  # asymptote
        plt.legend()
        plt.savefig(self._figure_path(str(num_epochs) + 'epochs_' + 'loss.jpg'), bbox_inches='tight')

        plt.figure(2, figsize=(8, 4))
        plt.title("D(x) and D(G(z)) During Training")
        plt.plot(self.D_x_list[::10], label="D(x)")
        plt.plot(self.D_z_list[::10], label="D(G(z))")
        plt.xlabel("iterations")
        plt.ylabel("Probability")
        plt.axhline(y=0.5, label="0.5", c="g")  # asymptote
        plt.legend()
        plt.savefig(self._figure_path(str(num_epochs) + 'epochs_' + 'D(x)D(G(z)).jpg'), bbox_inches='tight')

    def _save_animation(self, num_epochs):
        """Save the per-epoch preview grids as an animated GIF."""
        fig = plt.figure(3, figsize=(5, 5))
        plt.axis("off")
        ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]
        ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
        # The former HTML(ani.to_jshtml()) call was dropped: its result was
        # discarded, so it only rendered every frame a second time.
        ani.save(self._figure_path(str(num_epochs) + 'epochs_' + 'generation.gif'))

    def _plot_real_vs_fake(self, save_path, num=None):
        """Plot one real batch beside generated images, save to ``save_path`` and show.

        Puts the generator in eval mode before sampling from the fixed noise.
        ``num`` is forwarded to ``plt.figure`` as the figure identifier.
        """
        plt.figure(num, figsize=(8, 4))

        # Left panel: real images.
        plt.subplot(1, 2, 1)
        plt.axis("off")
        plt.title("Real Images")
        real = next(iter(self.dataloader))  # real[0] image, real[1] label
        plt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))

        # Right panel: generated images, moved to the CPU for matplotlib.
        self.G.eval()
        with torch.no_grad():
            fake = self.G(self.fixed_noise).cpu()
        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        fake = utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0)
        plt.imshow(fake)

        plt.savefig(save_path, bbox_inches='tight')
        plt.show()
from data import ReadData
from model import Discriminator, Generator, weights_init
from net import DCGAN
import torch
# Training entry script: builds the data loader and both networks, then runs
# DCGAN training. Everything below executes at import time.

# Any value > 0 selects 'cuda:0' when CUDA is available (see `device` below).
ngpu=1
# Base feature-map widths handed to Generator (ngf) and Discriminator (ndf).
ngf=64
ndf=64
# Image channel count handed to both networks.
nc=3
# Channel count of the latent noise vector.
nz=100
# Learning rate and first Adam beta forwarded to DCGAN.
# NOTE(review): 0.003 is far above the 0.0002 suggested in the DCGAN paper —
# confirm this value is intended.
lr=0.003
beta1=0.5
batch_size=100
# Number of fixed-noise samples / real images shown in the preview grids.
num_showimage=100
# NOTE(review): this path names the Oxford17 dataset, while the surrounding
# article describes training on a faces dataset — confirm the intended folder.
data_path="./oxford17_class"
# Output directories for checkpoints and figures.
model_save_path="./models/"
figure_save_path="./figures/"
# Use the first GPU when available and requested, otherwise the CPU.
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')
# ReadData is constructed with its default image_size.
dataset=ReadData(data_path)
dataloader=dataset.getdataloader(batch_size=batch_size)
# Build both networks and apply weights_init to every sub-module.
G = Generator(nz,ngf,nc).apply(weights_init)
print(G)
D = Discriminator(ndf,nc).apply(weights_init)
print(D)
# Wire everything together and train for 20 epochs.
dcgan=DCGAN( lr,beta1,nz,batch_size,num_showimage,device, model_save_path,figure_save_path,G, D, dataloader)
dcgan.train(num_epochs=20)
训练过程中Generator和Discriminator的Loss曲线图(以200个epoch为例):
训练过程中Discriminator输出(以200个epoch为例):
链接:https://pan.baidu.com/s/15J6sZL3rCPLm2jZFEuyzNw
提取码:DGAN
https://blog.csdn.net/qq_42951560/article/details/112199229
https://blog.csdn.net/qq_42951560/article/details/110308336
如果运行有问题,欢迎给我私信留言!