基于最基础的GAN生成动漫头像

最近在学习换脸相关的技术,在看完FaceShifter的论文和代码后就对GAN的思想产生了兴趣。而在看完陈云大佬的《深度学习框架PyTorch:入门与实践》GAN那一章后,就产生了用GAN生成动漫头像的念头。

基本思想

GAN,又叫生成对抗网络是一种非监督的学习。该网络中有一个生成器(Generators)和判别器(Discriminators),而训练过程就是这两个网络不断博弈对抗的过程。生成器不断生成假图企图通过判别器的识别,而判别器则将图片划分为真实图像和生成图像。

  • 在该模型中,生成器的输入是一串噪音,输出是一张生成的假图,而生成器致力于让判别器无法识别出这张假图是生成的图还是真实图像。在训练过程中,不断用判别器的分数做反馈使生成器效果越来越好。
  • 判别器的输入是一张图片,输出则是图片的分数,分数越高说明此时生成的图像越接近真实图像。判别器致力于识别图片是真图还是假图 ,在训练过程中不断投喂假图,输出一个分数再与真实图像的标签进行比较。实际上也是一个二分类的过程。

代码实现

获取数据

网上做这种模型的人非常多,所以动漫头像的数据集也非常多。不过我大致看了一下,网上的数据集中的动漫头像都非常古老,颇有90年代日本动漫的画风(可能都是老二次元)。在这个模型中,我用的是自己在一个网站爬下来的数据。网站链接在这/ | konachan.net - Konachan.com Anime Wallpapers。

爬虫

import time
import requests
import tqdm
from bs4 import BeautifulSoup
import os
import traceback# python异常模块

# 爬取图片
def download(url,filename,proxies):
    # 判断此时文件是否存在
    if os.path.exists(filename):
        print('file exists')
        return

    try:
        time.sleep(1)
        r = requests.get(url,stream=True,timeout=60,proxies=proxies)# 以流数据形式请求
        r.raise_for_status()
        with open(filename,'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:# 当这个文件存在时
                    f.write(chunk)
                    f.flush()
        return filename
    except KeyboardInterrupt:
        if os.path.exists(filename):# 此时出错说明该文件不存在任何数据,若保存过该文件则删除
            os.remove(filename)
        raise KeyboardInterrupt
    except Exception:
        traceback.print_exc()# 把返回信息输出到控制台
        if os.path.exists(filename):
            os.remove(filename)

if os.path.exists('imgs') is False:
    os.makedirs('imgs')

proxy = '127.0.0.1:58591' #
proxies = {
     'http': 'http://' + proxy,
     'https': 'https://' + proxy
 }

start = 1
end = 8000# 8k张图片
for i in tqdm.tqdm(range(start,end + 1),desc="download anime picture ing ~"):# tqdm括号内的必须是一个迭代器
    time.sleep(1)
    url =  'https://konachan.net/post?page=%d&tags=' % i# 网站
    html = requests.get(url,verify=True, proxies=proxies).text# 获取html网页上的内容
    soup = BeautifulSoup(html,'html.parser')
    for img in soup.find_all('img',class_="preview"):# 找到原网站中含有图片文件网站
        target_url = img['src']
        filename = os.path.join('imgs/true_imgs',target_url.split('/')[-1])
        download(target_url,filename,proxies)

可能是网站的原因,若不加sleep()会返回连接超时的报错,我猜可能是访问的太频繁了。不过具体原因我也不太清楚,对爬虫这一块不是很熟悉。

从这个网站爬下来的图片都是一些动漫壁纸,可能有些包含人物,而有些不包含。这里我用了openCV的一块模块来识别图像中的头像,并把它截取下来做为接下来训练的数据。

头像数据

# 从动漫壁纸中截取人物头像
import cv2
import sys
import os
from glob import glob

# 截取
def detect(filename,cascade_file = "lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)# 目标检测
    image = cv2.imread(filename)# 打开图片
    gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)

    faces = cascade.detectMultiScale(gray,
                                     scaleFactor=1.1,
                                     minNeighbors=5,
                                     minSize=(48,48))

    for i ,(x,y,w,h) in enumerate(faces):
        face = image[y:y+h,x:x+w,:]# 得到图像像素点的分布
        face = cv2.resize(face,(96,96))
        save_filename = '{}-{}.jpg'.format(os.path.basename(filename).split('.')[0],i)
        cv2.imwrite("data/faces/"+save_filename,face)# 写入文件

if __name__ == '__main__':
    if os.path.exists('data/faces') is False:
        os.makedirs('data/faces')
    file_list = glob('imgs/true_imgs/*.jpg')# 将imgs中所有图片路径整合为一个迭代器
    for filename in file_list:
        detect(filename)

这两块,我都借鉴了这位大佬的代码利用GAN生成动漫头像_一个追逐自我的程序员的博客-CSDN博客

网络结构

生成器

# 生成器
class NetG(nn.Module):

    def __init__(self,opt):
        super(NetG,self).__init__()
        ngf = opt.ngf# 生成器feature map数

        # 生成器主要的网络模块
        self.main = nn.Sequential(
            # 输入是一个nz维的噪音,是一个随机生成的张量,可以认为是大小为1x1的feature amp
            nn.ConvTranspose2d(opt.nz,ngf*8,kernel_size=4,stride=1,padding=0,bias = False),# 反卷积,做上采样
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf*8,ngf*4,kernel_size=4,stride=2,padding=1,bias=False),# 继续上采样,不断减小图片维度
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf*4,ngf*2,kernel_size=4,stride=2,padding=1,bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf*2,ngf,kernel_size=4,stride=2,padding=1,bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf,3,kernel_size=5,stride=3,padding=1,bias=False),
            nn.Tanh()# 输出范围固定在 -1 ~ 1故而采用Tanh
            # 输出形状:3 x 96 x 96
            # feature map经过解码过程,最后生成一个图片
        )

    def forward(self,input):
        return self.main(input)

判别器

# 判别器
class NetD(nn.Module):

    def __init__(self,opt):
        super(NetD,self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 输入3*96*96即生成器生成的图片
            nn.Conv2d(3,ndf,kernel_size=5,stride=3,padding=1,bias=False),# 卷积,下采样,也是编码的过程
            nn.LeakyReLU(0.2,inplace=True),
            # 输出 ndf x 32 x32

            nn.Conv2d(ndf,ndf*2,kernel_size=4,stride=2,padding=1,bias=False),# 正好将feature map图片大小缩小一半
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            # 输出 (ndf*2) x 16 x 16

            nn.Conv2d(ndf*2,ndf*4,kernel_size=4,stride=2,padding=1,bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            # 输出 (ndf*4) x 8 x 8

            nn.Conv2d(ndf*4,ndf*8,kernel_size=4,stride=2,padding=1,bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            # 输出 (ndf*8) x 4 x 4

            nn.Conv2d(ndf*8,1,kernel_size=4,stride = 1,padding=0,bias=False),# 最后编码成为一个维度为1的向量
            nn.Sigmoid()# 最后用Sigmoid作为分类,使得判别器成为一个判断二分类问题的模型,实际上判别器也是做一个二分类任务,判断是否为原图输出0或1
        )

    def forward(self,input):
        return self.main(input).view(-1)# 转成一个列向量,即sigmoid的结果在更前面的维度

我个人觉得,这个GAN中的下采样再上采样的过程应该也借鉴了U-Net的网络结构。不过这里是将噪音编码为图像,再将图像解码为一个score。

模型读取

def model(device,pth = False):
    netg = NetG(Config).to(device)
    netd = NetD(Config).to(device)
    if pth:
        netg.load_state_dict(torch.load(Config.load_G))
        netd.load_state_dict(torch.load(Config.load_D))
    return netg,netd

训练

参数

class Config():

    data_path = 'data/'
    num_workers = 4
    image_size = 96# 输入和输出的图片尺寸
    batch_size = 64
    max_epoch = 4000
    lr_G = 2e-4# 生成器的学习率
    lr_D = 2e-4# 判别器的学习率
    beta1 = 0.5# Adam优化器的beta1参数
    nz = 100# 产生的噪音维度
    ngf = 64# 生成器feature map数
    ndf = 64# 判别器feature map数

    save_img_path = 'generate_img'# 生成的图片保存路径
    save_model_G_path = 'ppppth/G'
    save_model_D_path = 'ppppth/D'

    load_D = 'ppppth/D/Anime_GAN_Dlast.pth'
    load_G = 'ppppth/G/Anime_GAN_Glast.pth'

    vis = True# 是否使用可视化
    env = 'GAN'
    plot_every = 20# 每间隔20 batch,visdom画图一次

    d_every = 1# 每一个batch训练一次判别器
    g_every = 5# 每五个batch训练一次生成器
    save_every = 20# 每20个epoch保存一次模型

    # 只测试不训练
    gen_img = 'imgs/generate_head/result.png'# 从512张生成的图片中保存最好的64张
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # 噪声的均值
    gen_std = 1  # 噪声的方差

这里的dataset我使用的是torch自带的ImageFolder

import os
import torch
import visdom
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as T
from model import NetD,NetG
from config import Config
from visualize import Visualizer

def train():
    # data
    data_path = Config.data_path
    image_size = Config.image_size
    transform = transforms()
    batch_size = Config.batch_size
    vis = Visualizer(Config.env)

    datasets = torchvision.datasets.ImageFolder(data_path,transform=transform)# 使用这个ImageFolder时,图片的路径必须是所处文件夹的上一级路径,即是data/而不是data/faces/
    dataloader = DataLoader(
        datasets,
        batch_size=batch_size,
        shuffle=True,
        num_workers=Config.num_workers,
        drop_last=True
    )

    # model
    device = 'cuda'
    G,D = model(device=device)

    # 优化器和损失函数
    lr_G = Config.lr_G# 生成器学习率
    lr_D = Config.lr_D# 判别器学习率
    beta = Config.beta1
    optimizer_G = torch.optim.Adam(G.parameters(),lr=lr_G,betas=(beta,0.999))
    optimizer_D = torch.optim.Adam(D.parameters(),lr=lr_D,betas=(beta,0.999))
    criterion = torch.nn.BCELoss().to(device)# 因为最终是一个二分类的问题

    # 标签,0为假图片,1为真图片
    t_label = torch.ones(batch_size).to(device)
    f_label = torch.zeros(batch_size).to(device)
    # 噪音,用于生成图片
    noise = torch.randn(batch_size,Config.nz,1,1).to(device)# 1x1大小的噪音
    val_noise = torch.randn(batch_size,Config.nz,1,1).to(device)

    epochs = Config.max_epoch
    loss_add_g = torch.tensor(0.0,device=device)
    loss_add_d = torch.tensor(0.0,device=device)

transform我没做什么特殊的数据增强

def transforms():
    transforms = T.Compose([
        T.CenterCrop(Config.image_size),
        T.ToTensor(),
        T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])
    return transforms

 

开始训练

    for epoch in range(epochs):
        for i,(img,_) in enumerate(dataloader):
            img = img.to(device)

            # 致力于让生成的假图骗过判别器
            if (i+1) % Config.g_every == 0:
                # 训练生成器
                optimizer_G.zero_grad()
                noise.data.copy_(torch.randn(batch_size,Config.nz,1,1))# 使每次训练生成器时,噪音不同
                fake_img = G(noise)
                out = D(fake_img)
                loss = criterion(out,t_label)
                loss_add_g += loss
                loss.backward()
                optimizer_G.step()
                loss_G_mean = loss_add_g / (i+1)

            # 致力于让判别器能识别出真图和假图
            if (i+1) % Config.d_every == 0:
                # 训练判别器,训练判别器要训练两部分
                optimizer_D.zero_grad()

                # 尽可能让判别器识别图片为真
                real_output = D(img)
                loss_r = criterion(real_output,t_label)# 使判别器尽量识别出源图片是真图片
                loss_r.backward()

                # 尽可能让判别器识别为假
                noise.data.copy_(torch.randn(batch_size,Config.nz,1,1))
                fake_img = G(noise)# 根据噪音生成图片
                fake_output = D(fake_img)
                loss_f = criterion(fake_output,f_label)# 使判别器尽量识别出生成的图片是假的图片
                loss_f.backward()
                loss = loss_f + loss_r

                optimizer_D.step()
                loss_add_d += loss
                loss_D_mean = loss_add_d / (i+1)

            # 每隔plot_every个batch在visdom上画一次图
            if Config.vis and i % Config.plot_every == Config.plot_every - 1:
                generate_img = G(val_noise)
                vis.images(generate_img.detach().cpu().numpy()[:64]*0.5+0.5,win='fake')
                vis.images(img.data.cpu().numpy()[:64]*0.5+0.5,win='real')
                vis.plot('loss_g',loss_G_mean.data.cpu().numpy())
                vis.plot('loss_d',loss_D_mean.data.cpu().numpy())
                loss_add_g = torch.tensor(0.0, device=device)
                loss_add_d = torch.tensor(0.0, device=device)
                print('Generators: loss_G {} , Discriminators: loss_D {}'.format(loss_G_mean, loss_D_mean))

        if (epoch + 1) % Config.save_every == 0:
            torch.save(G.state_dict(),os.path.join(Config.save_model_G_path,'Anime_GAN_Glast.pth'))
            torch.save(D.state_dict(), os.path.join(Config.save_model_D_path, 'Anime_GAN_Dlast.pth'))

可视化

数据的可视化参考了陈云大佬的代码,写了一个由visdom的实现的模块。

from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np


class Visualizer():
    """
    封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`
    调用原生的visdom接口
    """

    def __init__(self, env='default', **kwargs):
        import visdom
        self.vis = visdom.Visdom(env=env, use_incoming_socket=False,**kwargs)

        # 画的第几个数,相当于横座标
        # 保存(’loss',23) 即loss的第23个点
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """
        修改visdom的配置
        """
        self.vis = visdom.Visdom(env=env,use_incoming_socket=False, **kwargs)
        return self

    def plot_many(self, d):
        """
        一次plot多个
        @params d: dict (name,value) i.e. ('loss',0.11)
        """
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y):
        """
        self.plot('loss',1.00)
        """
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=(name),
                      opts=dict(title=name),
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def img(self, name, img_):
        """
        self.img('input_img',t.Tensor(64,64))
        """

        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)
        self.vis.image(img_.cpu(),
                       win=(name),
                       opts=dict(title=name)
                       )

    def img_grid_many(self, d):
        for k, v in d.items():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        """
        一个batch的图片转成一个网格图,i.e. input(36,64,64)
        会变成 6*6 的网格图,每个格子大小64*64
        """
        self.img(name, tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))

    def log(self, info, win='log_text'):
        """
        self.log({'loss':1,'lr':0.0001})
        """

        self.log_text += ('[{time}] {info} 
'.format( time=time.strftime('%m%d_%H%M%S'), info=info)) self.vis.text(self.log_text, win=win) def __getattr__(self, name): return getattr(self.vis, name)

在5K张图片的训练下,模型的效果我个人认为还可以,不过这里面一些参数都是借鉴的陈云大佬的pytorch-book/chapter07-AnimeGAN at master · chenyuntc/pytorch-book · GitHub并没有经过很细致的调参。而大佬的数据是几万张头像,训练了200个epoch,我这里5K张训练了4000个epoch。在数据量小的情况下,4000个epoch也就大概跑五六个小时吧(2060)。

这是200个epoch的结果

 1000个epoch

最终4000个epoch

虽然是基础的GAN,但在自己复现过程中还遇到了许多问题,也学到了很多。不得不说,GAN这个网络是真的有意思。在训练过程中,看着图片从一个个像素点逐渐变成二次元女孩子,有一种创作的感觉油然而生,虽然创作的是电脑。。。之后,我还会一步步学习经过改良的各种GAN网络。

你可能感兴趣的:(pytorch,深度学习,cnn)