图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

文章目录

  • 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)
    • 1.前言
    • 2.GAN网络主体架构介绍
    • 3.模型搭建
      • 3.1 生成器的搭建
      • 3.2 判别器模型的搭建
    • 4.数据预处理
    • 5.定义训练各项参数以及训练步骤
    • 6.训练以及效果可视化评估
    • 结语

1.前言

​ 深度学习,中有一个较为成熟并且非常重要的方向,GAN图像对抗生成网络,该网络在图像生成,图像增强,风格化领域,以及在艺术的图像创造(博主也是在看到一个关于中国山水画的GAN生成上,有了学习GAN的兴趣)有重要的作用。

​ 那么正所谓柿子要挑软的捏,学习从最简单的开始,在GAN方面完全是萌新的博主,今天介绍的自然也不是什么太难的架构,在本篇博客中,我会介绍GAN的大致架构,并用较为简单的方式从头到尾 (模型搭建,定义训练参数,训练步骤)实现他,如果本篇博客对你有帮助的话,别忘记点个赞。

( ̄▽ ̄)~■干杯□~( ̄▽ ̄)

2.GAN网络主体架构介绍

GAN的网络总体架构其实非常简单,他的中文名字对抗生成网络,意思是在他模型中包含两个网络,生成网络,对抗网络,总体结构如下图:

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)_第1张图片

​ 我们可以看到这张图上包含了两个网络Generator图像生成器和Discriminator图像分辨器,他的工作原理简单来说是这样的:我们的目标是想要一个生成图片那么我们如何去训练这个呢,这里GAN的开发者提出了这么一个想法,我们训练一个判别器,训练一个生成器,输入噪声(也就是我们提前规定好形状的随机初始化的向量)然后产生了图片,然后我们将真实的图片与虚假图片一起输入判别器,判别图片是否是真实的,利用在这里产生的损失去训练生成器,与判别器。那么我们可以想想如果这样的话我们最终产生的理想结果就是,判别器最终无法判别生成器生成的图片是真是假,最终预测的概率只有0.5(真假 二分类随机乱猜的概率)。

​ 当刚看懂网络工作方式的时候,我简直惊呆了,这是多么神奇的思维啊,生成器在训练中由于损失控制会努力希望生成的图片被判别为真,而判别器是希望能完全给出正确的判断(给生成的图片的判断全为0,真的图片判断全为1),那么在这两个模型的训练之间,他们在互相对抗,我们最终得到的将会是一个非常好的图像生成器,和 自编码器相比,(直接计算生成图与原图的差距)效果会更好(这里我会在之后的mnist数据生成展示中展示编码器与GAN网络的差别)。

3.模型搭建

​ 那么在介绍完模型之后,我们趁热打铁,直接开始模型的搭建,在上文中,我提到了两个模型负责生成图片的生成器,负责判别图片真假的判别器,接下来我开始分别搭建这两个模型。

3.1 生成器的搭建

​ 这里我们要搭建的是一个能够接收我们产生的随机初始化的向量,然后产生图片(这里我们产生的数据是mnist的手写体数字)的模型,这里我为了简单化全部采用全连接层来写模型

import tensorflow as tf
keras=tf.keras
layers=keras.layers
def generator_model():
    model=keras.Sequential()#
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False))#输入形状100是我输入噪声的形状,生成器一般都不使用BIAS
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#GAN中一般使用LeakyRelu函数来激活
    model.add(layers.Dense(512,use_bias=False))#生成器一般都不使用BIAS
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))#生成能够调整成我们想要的图片形状的向量
    model.add(layers.BatchNormalization())
    model.add(layers.Reshape((28,28,1)))#这里进行修改向量的形状,可以直接使用layers的reshape
    return model

可以看到经过全连接层的这样处理,我们输出的会是一个形状大小为(28,28,1)的图片,那么生成器的任务就是判断输入图片是否是生成的,也就是输入图片,输出0,1一个非常简单的二分类问题,那么我们就按照这个思路搭建我们的判别器网络。

3.2 判别器模型的搭建

判别器这里我也使用最基础的全连接层来创建(一方面是减少计算量,一方面是测试一下Dense层的效果)

def discriminator_model():
    model=keras.Sequential()
    model.add(layers.Flatten())#图片是一个三维数据,要输入到全连接层之前,先使用flatten层压平为一维的
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(256,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(1))#最后输出为0,1只需要一层
    return model

那么定义完组成模型的两个重要架构之后,我们为了接下来的训练需要准备处理好的数据,所以这里我们开始处理数据。

4.数据预处理

在本篇最简单的实战中,我采用深度学习中使用次数最多,入门级Hello World数据集,mnist手写体数据集,由数万张手写体数字组成

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)_第2张图片

这里的数字都是手写之后,经过特殊处理最终保存下来的。可以看到这样生成的图片是非常带有个人风格的(这写的也不太整齐。。),那么利用GAN生成网络去生成能类似人写的数据,达到可以欺骗人眼的效果,就是我此次的目的,那么废话少说,就开始我们此次的数据准备。

(x_train,y_train),_=keras.datasets.mnist.load_data()
x_train=tf.expand_dims(x_train,axis=-1)#这里由于输入的手写体是只有两个维度的,所以这里我扩展最后一个维度
x_train.shape
TensorShape([60000, 28, 28, 1])

扩展完维度后,为了方便模型运算,我们需要将数据进行归一化,规定数据集的BATCH_SIZE

x_train=tf.cast(x_train,tf.float32)
x_train=x_train/255.0
x_train=x_train*2-1#将图片数据规范到[-1,1]
BATCH_SIZE=256
BUFFER_SIZE=60000#每次训练弄乱的大小
dataset=tf.data.Dataset.from_tensor_slices(x_train)
dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

5.定义训练各项参数以及训练步骤

在搭建完模型,预处理好数据之后,接下来就需要定义模型所需优化器,损失计算函数,以及训练步骤

loss_object=keras.losses.BinaryCrossentropy(from_logits=True)#损失这里使用二分类交叉熵损失,没有激活是logits
def discriminator_loss(real_out,fake_out):     real_loss=loss_object(tf.ones_like(real_out),  real_out)
    fake_loss=loss_object(tf.zeros_like(fake_out),fake_out)
    return real_loss+fake_loss
#这里判别器使用的损失是计算我们人为制造的0,1标签与判别器模型输出的做计算,最终返回二者相加
def generator_loss(fake_out):
    fake_loss=loss_object(tf.ones_like(fake_out),fake_out)
    return fake_loss
#生成器计算损失当然是希望判别器都把他当真,所以是与1做计算

generator_opt=keras.optimizers.Adam(1e-4)
discriminator_opt=keras.optimizers.Adam(1e-4)#定义两个模型的优化器

定义完了在训练吗中需要用到的优化器损失函数,我们这里接下来定义模型,训练步骤并开始训练(这里我们会在每次训练后绘画随机生成的图片,来观察我们图像生成模型的效果,所以这里我会提前制作一个随机种子)

EPOCHS=100
noise_dim=100 #输入噪声的维度
num=16 #每次随机绘画16张图
seed=tf.random.normal(shape=([num,noise_dim])) #制作用于生成图片的向量
gen_model=generator_model()
dis_model=discriminator_model()
#初始化这两个模型
#定义训练步骤
@tf.function
def train_step(images):
    noise=tf.random.normal([BATCH_SIZE,noise_dim])
    with tf.GradientTape() as gentape, tf.GradientTape() as disctape:
        real_output=dis_model(images,training=True)
        fake_image=gen_model(noise,training=True)
        fake_output=dis_model(fake_image,training=True)
        gen_loss=generator_loss(fake_output)
        dis_loss=discriminator_loss(real_output,fake_output)
    grad_gen=gentape.gradient(gen_loss,gen_model.trainable_variables)
    grad_dis=disctape.gradient(dis_loss,dis_model.trainable_variables)
    generator_opt.apply_gradients(zip(grad_gen,gen_model.trainable_variables))
    discriminator_opt.apply_gradients(zip(grad_dis,dis_model.trainable_variables))
    
#在每次训练后绘图
def generate_plot_img(gen_model,test_noise):
    pre_img=gen_model(test_noise,training=False)
    fig=plt.figure(figsize=(4,4))
    for i in range(pre_img.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_img[i, :, :, 0]+1)/2,cmap='gray')
        #这里cmap限定绘图的颜色空间,灰度图
        plt.axis('off')
    plt.show()#将16张图片一起显示出来

6.训练以及效果可视化评估

那么我们开始训练

def train(dataset, epochs):
    for epoch in range(epochs):
        for img in dataset:
            train_step(img)
            print('-',end='')
        generate_plot_img(gen_model,seed)#绘制图片
train(dataset,EPOCHS)#这里EPOCHS我设置为100

那么由于我的随机数种子是固定的,所以这里我们随机生成的图片每次都是固定的数字,所以我们是可以看到效果在不断变好,如下

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)_第3张图片

这是第一次训练结束后生成的一团浆糊

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)_第4张图片

这是第五次训练产生的图像,可以看到已经渐渐产生了有数字的轮廓,

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)_第5张图片

在经过100次训练后最终我们看到我们的图像生成器,最后产生的图片已经非常有手写数字的轮廓。

虽然效果仍然不是很好,但其实是由于我这里完全使用了全连接层,在图像处理领域使用卷积神经网络会更好的效果,下图是我使用了卷积神经网络后的效果:

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)_第6张图片

结语

在本篇博客中,我完成了一个非常简单的GAN生成对抗网络,并训练该模型使得他可以生成非常接近的手写体的真实数据,对本篇博客有疑问或者建议的同学欢迎评论区交流。

你可能感兴趣的:(深度学习,tensorflow,神经网络)