[TensorFlow tf Trial-and-Error Log] Note 4 — Building a GAN with TensorFlow

GAN, the Generative Adversarial Network. Put simply, it is a bit like the Old Urchin's trick of fighting his left hand with his right: the machine tries to beat itself, and after all, your greatest enemy is yourself. Hahaha.

2018/3/31 18:35
Below are my latest generated results. The hidden layer size is 512; for a few harder digits such as 3, 5 and 2 I trained a separate 1024-unit network restricted to that specific label.
(sample images of the generated digits)

A brief introduction to GAN


A GAN mainly consists of a generator (Generator) and a discriminator (Discriminator). The discriminator judges real versus fake (and can also predict labels), while the generator produces samples and sends them to the discriminator for judgement. The generator tries to produce samples the discriminator cannot tell are fake, and the discriminator tries its best to catch every fake it is shown. The two keep fighting each other, a bit like breeding Gu: whatever survives the fight comes out stronger.

Project structure


This project has three main files and uses the MNIST dataset.

  • config.yml:
    • Holds the settings: the learning rate (learning_rate), the alpha used by Leaky ReLU, the number of epochs (epoches), and so on. We can also restrict training to images with a particular label: for example, train only on the digit '8' and generate '8'-like images (see the short example after this list). This makes tuning much more convenient.
  • train.py:
    • Holds the training code and the neural networks. I use TensorBoard here: the generator (Generator) loss and the discriminator (Discriminator) loss are each written to TensorBoard, so we can watch the loss curves while training runs. Every batch of images the generator (Generator) produces is also saved to disk, so we can watch the learning progress in real time.
  • utils.py:
    • Holds the helpers train.py needs, such as building the networks and preprocessing the data, e.g. batch_data().
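
For example, setting select_label to 8 in config.yml makes train.py call utils.select_data() so that only the '8' images are used. A minimal check of that helper on its own (my own snippet; the exact image count is approximate):

from tensorflow.examples.tutorials.mnist import input_data
import utils

mnist = input_data.read_data_sets( 'MNIST_data' )

# Keep only the training images whose label is 8, which is what
# select_label: 8 in config.yml triggers inside train.py.
eights = utils.select_data( mnist, 8 )
print( eights.shape )    # roughly (5400, 784): only the '8' images remain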

Utils

Generator


For the generator I use a fully connected network. The MNIST dataset is simple enough that a fully connected network already performs very well; a convolutional network would work too, it comes down to personal preference.

Because the TensorFlow version used here does not ship a Leaky ReLU function, it has to be implemented by hand. That is also why activation = None is passed to tf.layers.dense(): I do not want the layer to apply some other activation automatically.
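
Below is a quick sanity check of the hand-rolled Leaky ReLU (my own snippet, assuming TensorFlow 1.x graph mode; from TF 1.4 onward tf.nn.leaky_relu is also available and gives the same result):

import tensorflow as tf

x = tf.constant( [-2.0, -0.5, 0.0, 3.0] )
alpha = 0.01

manual = tf.maximum( alpha * x, x )              # the hand-rolled Leaky ReLU used below
built_in = tf.nn.leaky_relu( x, alpha = alpha )  # built in since TF 1.4

with tf.Session() as sess:
    print( sess.run( manual ) )     # [-0.02  -0.005  0.  3.]
    print( sess.run( built_in ) )   # same values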

For the final activation I call tf.tanh().

I put the variables created in this block under the scope name 'generator', so that later I can select variables by name when saving. That way only the generator is checkpointed and the discriminator is discarded.

def generator( z, out_dim, n_units = 128, reuse = False, alpha = 0.01 ):
    with tf.variable_scope( 'generator', reuse = reuse ):    # scope the variables created in this block
        h1 = tf.layers.dense( z, n_units, activation = None )    # fully connected hidden layer

        h1 = tf.maximum( alpha * h1, h1 )    # Leaky ReLU
        logits = tf.layers.dense( h1, out_dim, activation = None )
        out = tf.tanh( logits )

        return out

Discriminator


The discriminator also uses the Leaky ReLU activation. Its structure is similar to the generator (Generator), except that the activation of the last layer becomes tf.sigmoid() and that last fully connected layer has only a single neuron. The raw logits are returned as well, since I found the tf.sigmoid() output by itself did not help the model much; the losses below are computed from the logits.

As before, the variables in this block get their own scope name.

def discriminator( x, n_units = 128, reuse = False, alpha = 0.01 ):
    with tf.variable_scope( 'discriminator', reuse = reuse ):

        h1 = tf.layers.dense( x, n_units, activation = None )    

        # Leaky ReLU
        h1 = tf.maximum( alpha * h1, h1 )

        logits = tf.layers.dense( h1, 1, activation = None )
        out = tf.sigmoid( logits )

        return out, logits

Training (train.py)

Building the network


First we assemble the network we are going to run. The generator's (Generator) input is white noise; at this stage these are only placeholders, nothing has been initialized yet.

The discriminator (Discriminator) is called twice. The first call is fed the real images and tries not to misjudge them as fake; the second call is fed the fake images produced by the generator (Generator) and tries to recognize them as fake. To make sure both calls go through exactly the same discriminator (Discriminator), I pass reuse = True so that TensorFlow reuses the existing discriminator variables instead of creating a fresh copy.

tf.reset_default_graph()

input_real, input_z = utils.model_inputs( input_size, z_size )    # create the placeholders

'''--------Build the network--------'''
g_model = utils.generator( input_z, input_size, n_units = g_hidden_size, alpha = alpha )    # build the generator

d_model_real, d_logits_real = utils.discriminator( input_real, n_units = d_hidden_size, alpha = alpha )    # discriminate real images

d_model_fake, d_logits_fake = utils.discriminator( g_model, reuse = True, n_units = d_hidden_size, alpha = alpha )    # discriminate fake images
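
A quick way to convince yourself that reuse = True really shares the weights (my own check, meant to run right after the block above):

# The second discriminator call must not create any new trainable variables.
d_vars_check = [v for v in tf.trainable_variables() if v.name.startswith( 'discriminator' )]
print( [v.name for v in d_vars_check] )    # 4 variables (kernel + bias for each dense layer), despite two calls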

Computing the generator and discriminator losses


A GAN needs three losses: the generator (Generator) loss, the discriminator (Discriminator) loss on real images, and the discriminator (Discriminator) loss on fake images.

For the real images the target labels are reduced by smooth (label smoothing: with smooth = 0.1 the discriminator is trained toward 0.9 instead of 1.0), which helps the model generalize better.

The generator loss takes the discriminator's output on the fake images as input, with a target label of 1.
The adversarial nature shows up very clearly here: the discriminator does its best to flag the fake images, while the generator does its best to produce fakes the discriminator cannot flag.

d_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits = d_logits_real,
                                                                       labels = tf.ones_like( d_logits_real ) * ( 1 - smooth ), name = 'd_loss_real' ) )    # discriminator loss on real images

d_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits = d_logits_fake,
                                                                       labels = tf.zeros_like( d_logits_fake ), name = 'd_loss_fake' ) )    # discriminator loss on fake images

d_loss = d_loss_real + d_loss_fake    # total discriminator loss is the sum of the two

tf.summary.scalar('d_loss', d_loss)    # add d_loss to the TensorBoard summaries

g_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits = d_logits_fake,
                                                                  labels = tf.ones_like( d_logits_fake ), name = 'g_loss' ) )    # generator loss
tf.summary.scalar('g_loss', g_loss)    # add g_loss to the TensorBoard summaries
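
Between the losses and the training loop, the full train.py further down also splits the trainable variables by scope name and builds one Adam optimizer per sub-network, so each training step only updates its own weights:

learning_rate = FLAGS.learning_rate

# Split the trainable variables into generator and discriminator parts by scope name
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith( 'generator' )]
d_vars = [var for var in t_vars if var.name.startswith( 'discriminator' )]

# Each optimizer minimizes its own loss over its own variable list
d_train_opt = tf.train.AdamOptimizer( learning_rate ).minimize( d_loss, var_list = d_vars )
g_train_opt = tf.train.AdamOptimizer( learning_rate ).minimize( g_loss, var_list = g_vars )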

Putting the tools together


There is a small pitfall here: in the MNIST dataset that TensorFlow downloads, the pixel values are already scaled so the maximum is 1. Embarrassingly, at first everything I trained came out as a solid black blob.

The values have to be multiplied by 255 before they show up in a .jpg image. (Strictly speaking the generator's tanh output lies in (-1, 1), so a cleaner mapping would be (x + 1) / 2 * 255.) For the same reason the real inputs are rescaled to [-1, 1] with batch_images * 2 - 1 so they match the generator's tanh range.

Only the generator's variables are saved, because the focus of this project is generating data, not discriminating it.

batch_size = FLAGS.batch_size
epoches = FLAGS.epoches
samples = []
# losses = []
saver = tf.train.Saver( var_list = g_vars )    # only the generator variables are saved
with tf.Session() as sess:
    merged, writer = utils.print_training_loss(sess)

    sess.run( tf.global_variables_initializer() )
    for e in range( epoches ):
        for batch in batches:
            batch_images = batch
            batch_images = batch_images * 2 - 1

            batch_z = np.random.uniform( -1, 1, size = ( batch_size, z_size ) )
            '''--------Run the optimizers--------'''
            _ = sess.run( d_train_opt, feed_dict = {input_real : batch_images, input_z : batch_z} )
            _ = sess.run( g_train_opt, feed_dict = {input_z : batch_z} )

        train_loss_d = sess.run( d_loss, {input_z : batch_z, input_real : batch_images} )    # losses for printing
        train_loss_g = g_loss.eval( {input_z : batch_z} )

        print( 'Epoch {}/{}...' . format( e + 1, epoches ),
               'Discriminator Loss: {:.4f}...' . format( train_loss_d ),
               'Generator Loss: {:.4f}' . format( train_loss_g ) )   

        # Add data to tensorboard
        rs = sess.run(merged, feed_dict={input_z: batch_z, input_real: batch_images})
        writer.add_summary(rs, e)

        '''--------Generate and save images with the current generator--------'''
        sample_z = np.random.uniform( -1, 1, size = ( 16, z_size ) )
        gen_samples = sess.run(
            utils.generator( input_z, input_size, n_units = g_hidden_size, reuse = True, alpha = alpha),
            feed_dict = {input_z : sample_z} )


        gen_image = gen_samples.reshape( ( -1, 28, 28, 1 ) )    # reshape the array to ( height, width, channel )
        gen_image = tf.cast( np.multiply( gen_image, 255 ), tf.uint8 )    # multiply the values by 255
        for r in range( gen_image.shape[0] ):
            with open( FLAGS.generate_file + str(e) + ' ' + str( r ) + '.jpg', 'wb' ) as img:
                img.write( sess.run( tf.image.encode_jpeg( gen_image[r] ) ) )    # save the image

        samples.append( gen_samples )
        saver.save( sess, './checkpoint/generator.ckpt' )
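
Once training finishes, the generator-only checkpoint can be reloaded and sampled on its own. A minimal sketch (my own, assuming the same utils.py, the checkpoint path used above, and a g_hidden_size matching the value used during training):

import numpy as np
import tensorflow as tf
import utils

z_size = 100
input_size = 784
g_hidden_size = 256    # must match the value in config.yml used for training

tf.reset_default_graph()
input_z = tf.placeholder( tf.float32, ( None, z_size ), name = 'input_z' )
g_model = utils.generator( input_z, input_size, n_units = g_hidden_size )

# Restore only the generator variables (exactly what the Saver above wrote)
g_vars = [var for var in tf.trainable_variables() if var.name.startswith( 'generator' )]
saver = tf.train.Saver( var_list = g_vars )

with tf.Session() as sess:
    saver.restore( sess, './checkpoint/generator.ckpt' )
    sample_z = np.random.uniform( -1, 1, size = ( 16, z_size ) )
    samples = sess.run( g_model, feed_dict = {input_z : sample_z} )
    print( samples.shape )    # (16, 784), values in (-1, 1) from tanh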

Results


With config.yml configured, just run train.py and settle in for a long wait. How long depends on your machine; it usually finishes in twenty-odd minutes.
Open the folder where the generated images are saved and you will see that the early epochs produce a jumble of noise. Only near the end does the network generate something I can actually recognize:
a 4, a 9, a 1.

Complete code


train.py:

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import argparse
import os

import utils

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets( 'MNIST_data' )


'''--------Load the config file--------'''
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument( '-c', '--config', default = 'config.yml', help = 'The path to the config file' )

    return parser.parse_args()


args = parse_args()
FLAGS = utils.read_config_file( args.config )

if not( os.path.exists( FLAGS.generate_file ) ):
    os.makedirs( FLAGS.generate_file )


'''--------Preprocessing data--------'''
if( FLAGS.select_label != 'All' ):
    datas = utils.select_data( mnist, FLAGS.select_label )
else:
    datas = mnist.train.images    # shape ( 55000, 784 )

batches = utils.batch_data( datas, FLAGS.batch_size )


'''-----------Hyperparameters------------'''
# Size of input image to discriminator
input_size = 784
# Size of latent vector to generator
z_size = 100
# Size of hidden layers in generator and discriminator
g_hidden_size = FLAGS.g_hidden_size
d_hidden_size = FLAGS.d_hidden_size
# Leak factor for leaky ReLU
alpha = FLAGS.alpha
# Smoothing
smooth = 0.1


'''------------Build network-------------'''
tf.reset_default_graph()

# Create our input placeholders
input_real, input_z = utils.model_inputs( input_size, z_size )

# Build the model
g_model = utils.generator( input_z, input_size, n_units = g_hidden_size, alpha = alpha )
# g_model is the generator output

d_model_real, d_logits_real = utils.discriminator( input_real, n_units = d_hidden_size, alpha = alpha )
d_model_fake, d_logits_fake = utils.discriminator( g_model, reuse = True, n_units = d_hidden_size, alpha = alpha )


'''---Discriminator and Generator Losses---'''
# Calculate losses
d_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits = d_logits_real,
                                                                       labels = tf.ones_like( d_logits_real ) * ( 1 - smooth ), name = 'd_loss_real' ) )

d_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits = d_logits_fake,
                                                                       labels = tf.zeros_like( d_logits_fake ), name = 'd_loss_fake' ) )
d_loss = d_loss_real + d_loss_fake
# add d_loss to summary scalar
tf.summary.scalar('d_loss', d_loss)

g_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits = d_logits_fake,
                                                                  labels = tf.ones_like( d_logits_fake ), name = 'g_loss' ) )
# add g_loss to summary scalar
tf.summary.scalar('g_loss', g_loss)


'''---------------Optimizers----------------'''
# Optimizers
learning_rate = FLAGS.learning_rate

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith( 'generator' )]
d_vars = [var for var in t_vars if var.name.startswith( 'discriminator' )]

d_train_opt = tf.train.AdamOptimizer( learning_rate ).minimize( d_loss, var_list = d_vars )
g_train_opt = tf.train.AdamOptimizer( learning_rate ).minimize( g_loss, var_list = g_vars )


'''-----------------Training---------------------'''
batch_size = FLAGS.batch_size
epoches = FLAGS.epoches
samples = []
# losses = []
# Only save generator variables
saver = tf.train.Saver( var_list = g_vars )
with tf.Session() as sess:
    # Tensorboard Print Loss
    merged, writer = utils.print_training_loss(sess)

    sess.run( tf.global_variables_initializer() )
    for e in range( epoches ):
        for batch in batches:
            batch_images = batch
            batch_images = batch_images * 2 - 1

            # Sample random noise for G
            batch_z = np.random.uniform( -1, 1, size = ( batch_size, z_size ) )

            # Run optimizers
            _ = sess.run( d_train_opt, feed_dict = {input_real : batch_images, input_z : batch_z} )
            _ = sess.run( g_train_opt, feed_dict = {input_z : batch_z} )

        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run( d_loss, {input_z : batch_z, input_real : batch_images} )
        train_loss_g = g_loss.eval( {input_z : batch_z} )

        print( 'Epoch {}/{}...' . format( e + 1, epoches ),
               'Discriminator Loss: {:.4f}...' . format( train_loss_d ),
               'Generator Loss: {:.4f}' . format( train_loss_g ) )

        # Add data to tensorboard
        rs = sess.run(merged, feed_dict={input_z: batch_z, input_real: batch_images})
        writer.add_summary(rs, e)

        sample_z = np.random.uniform( -1, 1, size = ( 16, z_size ) )
        gen_samples = sess.run(
            utils.generator( input_z, input_size, n_units = g_hidden_size, reuse = True, alpha = alpha),
            feed_dict = {input_z : sample_z} )


        gen_image = gen_samples.reshape( ( -1, 28, 28, 1 ) )
        gen_image = tf.cast( np.multiply( gen_image, 255 ), tf.uint8 )
        for r in range( gen_image.shape[0] ):
            with open( FLAGS.generate_file + str(e) + ' ' + str( r ) + '.jpg', 'wb' ) as img:
                img.write( sess.run( tf.image.encode_jpeg( gen_image[r] ) ) )

        samples.append( gen_samples )
        saver.save( sess, './checkpoint/generator.ckpt' )

utils.py:

import tensorflow as tf
import yaml

def model_inputs( real_dim, z_dim ):
    inputs_real = tf.placeholder( tf.float32, ( None, real_dim ), name = 'input_real' )
    inputs_z = tf.placeholder( tf.float32, ( None, z_dim ), name = 'input_z' )

    return inputs_real, inputs_z

def generator( z, out_dim, n_units = 128, reuse = False, alpha = 0.01 ):
    with tf.variable_scope( 'generator', reuse = reuse ):
        # Hidden layer
        h1 = tf.layers.dense( z, n_units, activation = None )    # high-level wrapper for a fully connected layer

        h1 = tf.maximum( alpha * h1, h1 )

        # Logits and tanh output
        logits = tf.layers.dense( h1, out_dim, activation = None )
        out = tf.tanh( logits )

        return out

def discriminator( x, n_units = 128, reuse = False, alpha = 0.01 ):
    with tf.variable_scope( 'discriminator', reuse = reuse ):
        # Hidden layer
        h1 = tf.layers.dense( x, n_units, activation = None )    # high-level wrapper for a fully connected layer
        # Leaky ReLU
        h1 = tf.maximum( alpha * h1, h1 )

        logits = tf.layers.dense( h1, 1, activation = None )
        out = tf.sigmoid( logits )

        return out, logits

def print_training_loss( sess ):
    # tf.summary.scalar( 'd_loss', d_loss )
    # tf.summary.scalar( 'g_loss', g_loss )

    # merge all summary together
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter( 'logs/', sess.graph )

    return merged, writer

class Flag( object ):
    def __init__( self,**entries  ):
        self.__dict__.update( entries )

def read_config_file( config_file ):
    with open( config_file ) as f:
        FLAGES = Flag( **yaml.load( f ) )
    return FLAGES

def select_data( mnist, label ):
    data_idx = []
    for i in range( mnist.train.images.shape[0] ):
        if mnist.train.labels[i] == label:
            data_idx.append( i )
    datas = mnist.train.images[data_idx]
    return datas

def batch_data( datas, batch_size ):
    batches = []
    for i in range( datas.shape[0] // batch_size ):
        batch = datas[i * batch_size : ( i + 1 ) * batch_size, :]
        batches.append( batch )

    return batches

config.yml:

## hidden layer sizes and the Leaky ReLU alpha
g_hidden_size: 256
d_hidden_size: 256
alpha: 0.01

## set batch size and number of epoches
batch_size: 100
epoches: 100

## set learning_rate
learning_rate: 0.002

## the folder where the generated images are saved
generate_file: generate/

## the label to generate: must be a digit 0-9, or 'All'
select_label: 'All'

Project code

https://github.com/IronMastiff/GAN_MNIST
