Python基于Tensorflow实现DCGAN-动漫头像生成

目录

前言

DCGAN简介

python代码

1. 导入python包、定义全局变量

2. 读取数据

3. 搭建生成器generator

 4. 搭建判别器discriminator

5.  搭建GAN网络

6. 主函数

7. 补充:直接读模型生成图像

运行结果

1. loss曲线

2. 生成器生成动漫头像

完整代码、数据集、模型下载地址:


前言

去年做大创项目摸索方法时尝试了DCGAN,但后来放弃了这个方法,也没有太深入的研究,前段时间看到一个Deep Fake相关的比赛,突然对GAN网络产生了兴趣,又把以前DCGAN的代码重新整理了一遍,有什么理解有误的,欢迎批评指正。

        数据集和模型在文末有github下载链接。

DCGAN简介:

DCGAN原论文地址:1511.06434.pdf (arxiv.org)

DCGAN是对Goodfellow在2014年提出的原始GAN的一种改进:

(1)使用指定步长的卷积层代替池化层

(2)生成器和判别器中都使用BN

(3)移除全连接层

(4)生成器除去输出层采用Tanh外,全部使用ReLU作为激活函数

(5)判别器所有层都使用LeakyReLU作为激活函数

生成器(Generator)和判别器(Discriminator)模型如下:

Python基于Tensorflow实现DCGAN-动漫头像生成_第1张图片

DCGAN的基本流程如下,生成器通过反卷积生成虚假图像,判别器通过卷积对真实图像和虚假图像进行判别。

Python基于Tensorflow实现DCGAN-动漫头像生成_第2张图片

生成器总是希望自己能生成能以假乱真的图像迷惑判别器,而判别器总是希望自己能练就火眼金睛识破生成器的伎俩,两者相互对抗,相互竞争,不断地提高了自己的能力。

在不断的训练过程中,生成器一直通过判别器给出的loss调整自己,生成的图像也越来越接近真实图像,最终使得判别器误将生成的图像预测为真实图像,这也就达到了我们的目标:利用生成器生成有价值的仿真图像。

python代码

1. 导入python包、定义全局变量

import numpy as np 
import pandas as pd 
from PIL import Image
import matplotlib.pyplot as plt
import os
import keras
from keras import preprocessing
from keras.models import Sequential
from keras.layers import Conv2D,Dropout,Dense,Flatten,Conv2DTranspose,BatchNormalization,LeakyReLU,Reshape
import tensorflow as tf
import cv2
from tqdm import tqdm

noise_shape = 100
epochs = 300
batch_size = 128

2. 读取数据

def load_images():
    path_celeb = []
    train_path_celeb="/kaggle/input/face-data/data/"
    for path in os.listdir(train_path_celeb):
        if '.png' in path:
            path_celeb.append(os.path.join(train_path_celeb, path))
    images=[np.array(Image.open(onepath)) for onepath in  path_celeb]
    #min-max标准化
    for i in range(len(images)):
        images[i] = ((images[i] - images[i].min())/(images[i].max() - images[i].min()))  
    images = np.array(images)
    return images

3. 搭建生成器generator

DCGAN的论文里提到生成器使用ReLU作为激活函数,最后一层使用tanh,我最开始照着他的来效果不好,后来我把激活函数改成下面这样,效果反而好多了,可能数据集不同的原因?

# 生成器,将输入的1*100噪声生成为64*64*3的动漫图片
def get_generator():
    generator=Sequential()
    generator.add(Dense(4*4*512,input_shape=[noise_shape]))
    generator.add(Reshape([4,4,512]))
    generator.add(Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"))
    generator.add(LeakyReLU(alpha=0.2))
    generator.add(BatchNormalization())
    generator.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"))
    generator.add(LeakyReLU(alpha=0.2))
    generator.add(BatchNormalization())
    generator.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"))
    generator.add(LeakyReLU(alpha=0.2))
    generator.add(BatchNormalization())
    generator.add(Conv2DTranspose(3, kernel_size=4, strides=2, padding="same",activation='sigmoid'))
    return generator

 4. 搭建判别器discriminator

# 判别器,判断输入的动漫头像图片是真还是假
def get_discriminator():
    discriminator=Sequential()
    discriminator.add(Conv2D(32, kernel_size=4, strides=2, padding="same",input_shape=[64,64, 3]))
    discriminator.add(Conv2D(64, kernel_size=4, strides=2, padding="same"))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(BatchNormalization())
    discriminator.add(Conv2D(128, kernel_size=4, strides=2, padding="same"))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(BatchNormalization())
    discriminator.add(Conv2D(256, kernel_size=4, strides=2, padding="same"))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Flatten())
    discriminator.add(Dropout(0.5))
    discriminator.add(Dense(1,activation='sigmoid'))
    return discriminator

5.  搭建GAN网络

# 将生成器和判别器组合为GAN网络
def get_gan(generator,discriminator):
    GAN =Sequential([generator,discriminator])
    discriminator.compile(optimizer='adam',loss='binary_crossentropy')
    #discriminator.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),
                        #  loss='binary_crossentropy')
    discriminator.trainable = False
    GAN.compile(optimizer='adam',loss='binary_crossentropy')
    #GAN.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),
               # loss='binary_crossentropy')
    #GAN.layers
    return GAN

6. 主函数

def main():
    #读取数据集
    train_data=load_images()
    # 可视化部分图片
    plt.figure(figsize=(10,10))
    fig,ax=plt.subplots(2,5)
    fig.suptitle("Real Images")
    idx=100
    for i in range(2):
        for j in range(5):
            ax[i,j].imshow(train_data[idx].reshape(64,64,3))
            idx+=1            
    plt.tight_layout()
    plt.show()
    
    # 初始化网络模型
    generator=get_generator()
    # 查看网络结构
    generator.summary()
    discriminator=get_discriminator()
    discriminator.summary()
    GAN=get_gan(generator,discriminator)
    GAN.summary()
    
    # 判别器loss
    D_loss=[] 
    # 生成器loss
    G_loss=[] 
    # tqdm提供进度条显示的功能
    with tqdm(total=epochs) as _tqdm:
        for epoch in range(epochs):
            # 设置进度条前缀
            _tqdm.set_description('epoch: {}/{}'.format(epoch , epochs))
            for i in range(train_data.shape[0]//batch_size):
                # 生成随机噪声
                noise=np.random.uniform(-1,1,size=[batch_size,noise_shape])
                # 将噪声输入生成器,生成假图片
                gen_image = generator.predict_on_batch(noise)
                # 分批次取数据训练
                train_dataset = train_data[i*batch_size:(i+1)*batch_size]
                # 生成真图片的标签
                train_label=np.ones(shape=(batch_size,1))
                discriminator.trainable = True
                # 将真图片输入判别器进行预测
                d_loss1 = discriminator.train_on_batch(train_dataset,train_label)
    
                #将生成器生成的假图片输入判别器进行预测
                train_label=np.zeros(shape=(batch_size,1))
                d_loss2 = discriminator.train_on_batch(gen_image,train_label)
                # 再生成随机噪声,用以训练生成器
                noise=np.random.uniform(-1,1,size=[batch_size,noise_shape])
                train_label=np.ones(shape=(batch_size,1))
                # 此时不再训练判别器
                discriminator.trainable = False
                #训练生成器
                g_loss = GAN.train_on_batch(noise, train_label)
                D_loss.append(d_loss1+d_loss2)
                G_loss.append(g_loss)
            # 每5个epoch查看一次生成器生成效果
            if epoch % 5 == 0:
                samples = 10
                x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples,100)))
                for k in range(samples):
                    plt.subplot(2, 5, k+1)
                    plt.imshow(x_fake[k].reshape(64,64,3))
                    plt.xticks([])
                    plt.yticks([])
                plt.tight_layout()
                plt.show()
            # 设置进度条后缀        
            _tqdm.set_postfix(D_loss = '{:.3f}'.format(d_loss1+d_loss2),  G_loss = '{:.3f}'.format( g_loss))
            # 进度条更新
            _tqdm.update(1)
    print('Training is complete')
    
    # 保存训练好的生成器,判别器模型
    generator.save('generator300.h5')
    discriminator.save('discriminator300.h5')
    print('models are saved')
    
    # 
    for i in range(5):
        plt.figure(figsize=(7,7))   
        for k in range(20):
            noise=np.random.uniform(-1,1,size=[20,noise_shape])
            im=generator.predict(noise)
            # 对生成结果降噪,下面三个分别是高斯滤波、中值滤波、不进行滤波,三个选择一个取消注释就行,可以尝试摸索一下其他降噪方法,看看哪种效果好
            #im= cv2.GaussianBlur(im[k], (3, 3),0.5)
            #im=cv2.medianBlur(im[k], 3)
            im=im[k]
            plt.subplot(5, 4, k+1)
            plt.imshow(im.reshape(64,64,3))
            plt.xticks([])
            plt.yticks([])
        plt.tight_layout()
        plt.show()
    # 可视化loss    
    plt.figure(figsize=(10,10))
    plt.plot(G_loss,color='red',label='Generator_loss')
    plt.plot(D_loss,color='blue',label='Discriminator_loss')
    plt.legend()
    plt.xlabel('total batches')
    plt.ylabel('loss')
    plt.title('Model loss per batch')
    plt.show()

7. 补充:直接读模型生成图像

import matplotlib.pyplot as plt
import keras


generator=keras.models.load_model("../input/gan-model/generator.h5")

for i in range(5):
        plt.figure(figsize=(7,7))   
        for k in range(20):
            # 噪声范围不同生成的图像会有明显的区别,可以尝试设为(-1,0)比较和(-1,1的区别)
            noise=np.random.uniform(-1,0.5,size=[20,100])
            im=generator.predict(noise)
            # 减噪
            #im= cv2.GaussianBlur(im[k], (3, 3),0.8)
            #im=cv2.medianBlur(im[k], 3)
            im=im[k]
            plt.subplot(5, 4, k+1)
            plt.imshow(im.reshape(64,64,3))
            plt.xticks([])
            plt.yticks([])
        plt.tight_layout()
        plt.show()

运行结果

1. loss曲线

好像也看不出个什么?原论文里也没展示loss曲线。

Python基于Tensorflow实现DCGAN-动漫头像生成_第3张图片

2. 生成器生成动漫头像

训练了300个epoch,最终生成的图片如下,感觉效果还行,在最后预测时设置不同的噪声范围,生成的图像会有明显的区别,比如在(-1,0)的发色都是蓝色,(0,1)之间的发色都是黄色,可以自己调着试试。

关于降噪,生成的图像确实不够清晰,用高斯模糊处理给后噪点少了但是很确实模糊,以后有时间了尝试用超分辨算法来处理,下面的都是没有进行降噪操作的图片:

Python基于Tensorflow实现DCGAN-动漫头像生成_第4张图片

Python基于Tensorflow实现DCGAN-动漫头像生成_第5张图片

数据集、模型下载地址:

face_data | Kaggle

gan_model | Kaggle

你可能感兴趣的:(人工智能学习,深度学习,人工智能,tensorflow,cnn,python)