最近在学习换脸相关的技术,在看完FaceShifter的论文和代码后就对GAN的思想产生了兴趣。而在看完陈云大佬的《深度学习框架PyTorch:入门与实践》GAN那一章后,就产生了用GAN生成动漫头像的念头。
GAN,又叫生成对抗网络是一种非监督的学习。该网络中有一个生成器(Generators)和判别器(Discriminators),而训练过程就是这两个网络不断博弈对抗的过程。生成器不断生成假图企图通过判别器的识别,而判别器则将图片划分为真实图像和生成图像。
网上做这种模型的人非常多,所以动漫头像的数据集也非常多。不过我大致看了一下,网上的数据集中的动漫头像都非常古老,颇有90年代日本动漫的画风(可能都是老二次元)。在这个模型中,我用的是自己在一个网站爬下来的数据。网站链接在这/ | konachan.net - Konachan.com Anime Wallpapers。
import time
import requests
import tqdm
from bs4 import BeautifulSoup
import os
import traceback# python异常模块
# 爬取图片
def download(url,filename,proxies):
# 判断此时文件是否存在
if os.path.exists(filename):
print('file exists')
return
try:
time.sleep(1)
r = requests.get(url,stream=True,timeout=60,proxies=proxies)# 以流数据形式请求
r.raise_for_status()
with open(filename,'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk:# 当这个文件存在时
f.write(chunk)
f.flush()
return filename
except KeyboardInterrupt:
if os.path.exists(filename):# 此时出错说明该文件不存在任何数据,若保存过该文件则删除
os.remove(filename)
raise KeyboardInterrupt
except Exception:
traceback.print_exc()# 把返回信息输出到控制台
if os.path.exists(filename):
os.remove(filename)
if os.path.exists('imgs') is False:
os.makedirs('imgs')
proxy = '127.0.0.1:58591' #
proxies = {
'http': 'http://' + proxy,
'https': 'https://' + proxy
}
start = 1
end = 8000# 8k张图片
for i in tqdm.tqdm(range(start,end + 1),desc="download anime picture ing ~"):# tqdm括号内的必须是一个迭代器
time.sleep(1)
url = 'https://konachan.net/post?page=%d&tags=' % i# 网站
html = requests.get(url,verify=True, proxies=proxies).text# 获取html网页上的内容
soup = BeautifulSoup(html,'html.parser')
for img in soup.find_all('img',class_="preview"):# 找到原网站中含有图片文件网站
target_url = img['src']
filename = os.path.join('imgs/true_imgs',target_url.split('/')[-1])
download(target_url,filename,proxies)
可能是网站的原因,若不加sleep()会返回连接超时的报错,我猜可能是访问的太频繁了。不过具体原因我也不太清楚,对爬虫这一块不是很熟悉。
从这个网站爬下来的图片都是一些动漫壁纸,可能有些包含人物,而有些不包含。这里我用了openCV的一块模块来识别图像中的头像,并把它截取下来做为接下来训练的数据。
# 从动漫壁纸中截取人物头像
import cv2
import sys
import os
from glob import glob
# 截取
def detect(filename,cascade_file = "lbpcascade_animeface.xml"):
if not os.path.isfile(cascade_file):
raise RuntimeError("%s: not found" % cascade_file)
cascade = cv2.CascadeClassifier(cascade_file)# 目标检测
image = cv2.imread(filename)# 打开图片
gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
gray = cv2.equalizeHist(gray)
faces = cascade.detectMultiScale(gray,
scaleFactor=1.1,
minNeighbors=5,
minSize=(48,48))
for i ,(x,y,w,h) in enumerate(faces):
face = image[y:y+h,x:x+w,:]# 得到图像像素点的分布
face = cv2.resize(face,(96,96))
save_filename = '{}-{}.jpg'.format(os.path.basename(filename).split('.')[0],i)
cv2.imwrite("data/faces/"+save_filename,face)# 写入文件
if __name__ == '__main__':
if os.path.exists('data/faces') is False:
os.makedirs('data/faces')
file_list = glob('imgs/true_imgs/*.jpg')# 将imgs中所有图片路径整合为一个迭代器
for filename in file_list:
detect(filename)
这两块,我都借鉴了这位大佬的代码利用GAN生成动漫头像_一个追逐自我的程序员的博客-CSDN博客
# 生成器
class NetG(nn.Module):
def __init__(self,opt):
super(NetG,self).__init__()
ngf = opt.ngf# 生成器feature map数
# 生成器主要的网络模块
self.main = nn.Sequential(
# 输入是一个nz维的噪音,是一个随机生成的张量,可以认为是大小为1x1的feature amp
nn.ConvTranspose2d(opt.nz,ngf*8,kernel_size=4,stride=1,padding=0,bias = False),# 反卷积,做上采样
nn.BatchNorm2d(ngf*8),
nn.ReLU(True),
# 上一步的输出形状:(ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf*8,ngf*4,kernel_size=4,stride=2,padding=1,bias=False),# 继续上采样,不断减小图片维度
nn.BatchNorm2d(ngf*4),
nn.ReLU(True),
# 上一步的输出形状: (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf*4,ngf*2,kernel_size=4,stride=2,padding=1,bias=False),
nn.BatchNorm2d(ngf*2),
nn.ReLU(True),
# 上一步的输出形状: (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf*2,ngf,kernel_size=4,stride=2,padding=1,bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# 上一步的输出形状:(ngf) x 32 x 32
nn.ConvTranspose2d(ngf,3,kernel_size=5,stride=3,padding=1,bias=False),
nn.Tanh()# 输出范围固定在 -1 ~ 1故而采用Tanh
# 输出形状:3 x 96 x 96
# feature map经过解码过程,最后生成一个图片
)
def forward(self,input):
return self.main(input)
# 判别器
class NetD(nn.Module):
def __init__(self,opt):
super(NetD,self).__init__()
ndf = opt.ndf
self.main = nn.Sequential(
# 输入3*96*96即生成器生成的图片
nn.Conv2d(3,ndf,kernel_size=5,stride=3,padding=1,bias=False),# 卷积,下采样,也是编码的过程
nn.LeakyReLU(0.2,inplace=True),
# 输出 ndf x 32 x32
nn.Conv2d(ndf,ndf*2,kernel_size=4,stride=2,padding=1,bias=False),# 正好将feature map图片大小缩小一半
nn.BatchNorm2d(ndf*2),
nn.LeakyReLU(0.2,inplace=True),
# 输出 (ndf*2) x 16 x 16
nn.Conv2d(ndf*2,ndf*4,kernel_size=4,stride=2,padding=1,bias=False),
nn.BatchNorm2d(ndf*4),
nn.LeakyReLU(0.2,inplace=True),
# 输出 (ndf*4) x 8 x 8
nn.Conv2d(ndf*4,ndf*8,kernel_size=4,stride=2,padding=1,bias=False),
nn.BatchNorm2d(ndf*8),
nn.LeakyReLU(0.2,inplace=True),
# 输出 (ndf*8) x 4 x 4
nn.Conv2d(ndf*8,1,kernel_size=4,stride = 1,padding=0,bias=False),# 最后编码成为一个维度为1的向量
nn.Sigmoid()# 最后用Sigmoid作为分类,使得判别器成为一个判断二分类问题的模型,实际上判别器也是做一个二分类任务,判断是否为原图输出0或1
)
def forward(self,input):
return self.main(input).view(-1)# 转成一个列向量,即sigmoid的结果在更前面的维度
我个人觉得,这个GAN中的下采样再上采样的过程应该也借鉴了U-Net的网络结构。不过这里是将噪音编码为图像,再将图像解码为一个score。
def model(device,pth = False):
netg = NetG(Config).to(device)
netd = NetD(Config).to(device)
if pth:
netg.load_state_dict(torch.load(Config.load_G))
netd.load_state_dict(torch.load(Config.load_D))
return netg,netd
class Config():
data_path = 'data/'
num_workers = 4
image_size = 96# 输入和输出的图片尺寸
batch_size = 64
max_epoch = 4000
lr_G = 2e-4# 生成器的学习率
lr_D = 2e-4# 判别器的学习率
beta1 = 0.5# Adam优化器的beta1参数
nz = 100# 产生的噪音维度
ngf = 64# 生成器feature map数
ndf = 64# 判别器feature map数
save_img_path = 'generate_img'# 生成的图片保存路径
save_model_G_path = 'ppppth/G'
save_model_D_path = 'ppppth/D'
load_D = 'ppppth/D/Anime_GAN_Dlast.pth'
load_G = 'ppppth/G/Anime_GAN_Glast.pth'
vis = True# 是否使用可视化
env = 'GAN'
plot_every = 20# 每间隔20 batch,visdom画图一次
d_every = 1# 每一个batch训练一次判别器
g_every = 5# 每五个batch训练一次生成器
save_every = 20# 每20个epoch保存一次模型
# 只测试不训练
gen_img = 'imgs/generate_head/result.png'# 从512张生成的图片中保存最好的64张
gen_num = 64
gen_search_num = 512
gen_mean = 0 # 噪声的均值
gen_std = 1 # 噪声的方差
这里的dataset我使用的是torch自带的ImageFolder
import os
import torch
import visdom
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as T
from model import NetD,NetG
from config import Config
from visualize import Visualizer
def train():
# data
data_path = Config.data_path
image_size = Config.image_size
transform = transforms()
batch_size = Config.batch_size
vis = Visualizer(Config.env)
datasets = torchvision.datasets.ImageFolder(data_path,transform=transform)# 使用这个ImageFolder时,图片的路径必须是所处文件夹的上一级路径,即是data/而不是data/faces/
dataloader = DataLoader(
datasets,
batch_size=batch_size,
shuffle=True,
num_workers=Config.num_workers,
drop_last=True
)
# model
device = 'cuda'
G,D = model(device=device)
# 优化器和损失函数
lr_G = Config.lr_G# 生成器学习率
lr_D = Config.lr_D# 判别器学习率
beta = Config.beta1
optimizer_G = torch.optim.Adam(G.parameters(),lr=lr_G,betas=(beta,0.999))
optimizer_D = torch.optim.Adam(D.parameters(),lr=lr_D,betas=(beta,0.999))
criterion = torch.nn.BCELoss().to(device)# 因为最终是一个二分类的问题
# 标签,0为假图片,1为真图片
t_label = torch.ones(batch_size).to(device)
f_label = torch.zeros(batch_size).to(device)
# 噪音,用于生成图片
noise = torch.randn(batch_size,Config.nz,1,1).to(device)# 1x1大小的噪音
val_noise = torch.randn(batch_size,Config.nz,1,1).to(device)
epochs = Config.max_epoch
loss_add_g = torch.tensor(0.0,device=device)
loss_add_d = torch.tensor(0.0,device=device)
transform我没做什么特殊的数据增强
def transforms(): transforms = T.Compose([ T.CenterCrop(Config.image_size), T.ToTensor(), T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) ]) return transforms
for epoch in range(epochs):
for i,(img,_) in enumerate(dataloader):
img = img.to(device)
# 致力于让生成的假图骗过判别器
if (i+1) % Config.g_every == 0:
# 训练生成器
optimizer_G.zero_grad()
noise.data.copy_(torch.randn(batch_size,Config.nz,1,1))# 使每次训练生成器时,噪音不同
fake_img = G(noise)
out = D(fake_img)
loss = criterion(out,t_label)
loss_add_g += loss
loss.backward()
optimizer_G.step()
loss_G_mean = loss_add_g / (i+1)
# 致力于让判别器能识别出真图和假图
if (i+1) % Config.d_every == 0:
# 训练判别器,训练判别器要训练两部分
optimizer_D.zero_grad()
# 尽可能让判别器识别图片为真
real_output = D(img)
loss_r = criterion(real_output,t_label)# 使判别器尽量识别出源图片是真图片
loss_r.backward()
# 尽可能让判别器识别为假
noise.data.copy_(torch.randn(batch_size,Config.nz,1,1))
fake_img = G(noise)# 根据噪音生成图片
fake_output = D(fake_img)
loss_f = criterion(fake_output,f_label)# 使判别器尽量识别出生成的图片是假的图片
loss_f.backward()
loss = loss_f + loss_r
optimizer_D.step()
loss_add_d += loss
loss_D_mean = loss_add_d / (i+1)
# 每隔plot_every个batch在visdom上画一次图
if Config.vis and i % Config.plot_every == Config.plot_every - 1:
generate_img = G(val_noise)
vis.images(generate_img.detach().cpu().numpy()[:64]*0.5+0.5,win='fake')
vis.images(img.data.cpu().numpy()[:64]*0.5+0.5,win='real')
vis.plot('loss_g',loss_G_mean.data.cpu().numpy())
vis.plot('loss_d',loss_D_mean.data.cpu().numpy())
loss_add_g = torch.tensor(0.0, device=device)
loss_add_d = torch.tensor(0.0, device=device)
print('Generators: loss_G {} , Discriminators: loss_D {}'.format(loss_G_mean, loss_D_mean))
if (epoch + 1) % Config.save_every == 0:
torch.save(G.state_dict(),os.path.join(Config.save_model_G_path,'Anime_GAN_Glast.pth'))
torch.save(D.state_dict(), os.path.join(Config.save_model_D_path, 'Anime_GAN_Dlast.pth'))
数据的可视化参考了陈云大佬的代码,写了一个由visdom的实现的模块。
from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np
class Visualizer():
"""
封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`
调用原生的visdom接口
"""
def __init__(self, env='default', **kwargs):
import visdom
self.vis = visdom.Visdom(env=env, use_incoming_socket=False,**kwargs)
# 画的第几个数,相当于横座标
# 保存(’loss',23) 即loss的第23个点
self.index = {}
self.log_text = ''
def reinit(self, env='default', **kwargs):
"""
修改visdom的配置
"""
self.vis = visdom.Visdom(env=env,use_incoming_socket=False, **kwargs)
return self
def plot_many(self, d):
"""
一次plot多个
@params d: dict (name,value) i.e. ('loss',0.11)
"""
for k, v in d.items():
self.plot(k, v)
def img_many(self, d):
for k, v in d.items():
self.img(k, v)
def plot(self, name, y):
"""
self.plot('loss',1.00)
"""
x = self.index.get(name, 0)
self.vis.line(Y=np.array([y]), X=np.array([x]),
win=(name),
opts=dict(title=name),
update=None if x == 0 else 'append'
)
self.index[name] = x + 1
def img(self, name, img_):
"""
self.img('input_img',t.Tensor(64,64))
"""
if len(img_.size()) < 3:
img_ = img_.cpu().unsqueeze(0)
self.vis.image(img_.cpu(),
win=(name),
opts=dict(title=name)
)
def img_grid_many(self, d):
for k, v in d.items():
self.img_grid(k, v)
def img_grid(self, name, input_3d):
"""
一个batch的图片转成一个网格图,i.e. input(36,64,64)
会变成 6*6 的网格图,每个格子大小64*64
"""
self.img(name, tv.utils.make_grid(
input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))
def log(self, info, win='log_text'):
"""
self.log({'loss':1,'lr':0.0001})
"""
self.log_text += ('[{time}] {info}
'.format(
time=time.strftime('%m%d_%H%M%S'),
info=info))
self.vis.text(self.log_text, win=win)
def __getattr__(self, name):
return getattr(self.vis, name)
在5K张图片的训练下,模型的效果我个人认为还可以,不过这里面一些参数都是借鉴的陈云大佬的pytorch-book/chapter07-AnimeGAN at master · chenyuntc/pytorch-book · GitHub并没有经过很细致的调参。而大佬的数据是几万张头像,训练了200个epoch,我这里5K张训练了4000个epoch。在数据量小的情况下,4000个epoch也就大概跑五六个小时吧(2060)。
这是200个epoch的结果
1000个epoch
最终4000个epoch
虽然是基础的GAN,但在自己复现过程中还遇到了许多问题,也学到了很多。不得不说,GAN这个网络是真的有意思。在训练过程中,看着图片从一个个像素点逐渐变成二次元女孩子,有一种创作的感觉油然而生,虽然创作的是电脑。。。之后,我还会一步步学习经过改良的各种GAN网络。