GAN 及其 TensorFlow 2.0 实现

GAN 的原理可参见 GAN 分类下的相关文章,这里不再赘述。下面主要介绍如何使用 TensorFlow 2.0 实现一个简单的 GAN(以 MNIST 手写数字数据集为例)。

  • 首先我们需要定义模型中所需的一些超参数、损失函数和优化器
# ---- hyperparameters ----
BUFFER_SIZE = 60000               # shuffle-buffer size for the training dataset
BATCH_SIZE = 256                  # examples per training batch
EPOCHS = 50                       # full passes over the dataset
z_dim = 100                       # length of the generator's noise vector
num_examples_to_generate = 16     # samples drawn for each progress image

# ---- loss and optimizers ----
# The discriminator ends in a Dense(1) layer with no activation, so the loss
# is told it receives raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# One Adam instance per network so their optimizer state stays independent.
g_optimizer, d_optimizer = (keras.optimizers.Adam(learning_rate=1e-4),
                            keras.optimizers.Adam(learning_rate=1e-4))
  • 定义生成器:这里使用 Keras 的 Sequential 模型,由一个全连接层和三个转置卷积层构成,把 100 维噪声向量映射为 28×28×1 的图像
# generator
def make_generator(latent_dim=100):
    """Build the generator: noise vector -> 28x28x1 image with values in [-1, 1].

    A Dense projection is reshaped to a 7x7x256 feature map and upsampled by
    three transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28).

    Args:
        latent_dim: length of the input noise vector. Defaults to 100, the
            value that used to be hard-coded; pass ``z_dim`` to keep the two
            in sync.

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    return keras.Sequential([
        # BatchNormalization learns its own offset, so layer biases are dropped.
        keras.layers.Dense(7 * 7 * 256, use_bias=False,
                           input_shape=(latent_dim,)),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        keras.layers.Reshape((7, 7, 256)),
        # stride 1 keeps the 7x7 grid; 128 channels
        keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                     padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # stride 2: 7x7 -> 14x14; 64 channels
        keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # stride 2: 14x14 -> 28x28 with one channel; tanh bounds the output
        # to [-1, 1], the range the training images are rescaled to.
        keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False,
                                     activation='tanh'),
    ])
  • 再定义判别器:两个步长为 2 的卷积块之后接一个全连接层,输出一个未经激活的分数(logit)
def make_discriminator():
    """Build the discriminator: two strided conv blocks followed by one logit.

    The first Conv2D declares no input shape, so the weights are created on
    the model's first call. The final Dense(1) has no activation, i.e. the
    model outputs an unnormalized score (logit) per example.
    """
    model = keras.Sequential()
    for filters in (64, 128):
        # stride 2 halves height and width in each block
        model.add(keras.layers.Conv2D(filters, (5, 5), strides=(2, 2),
                                      padding='same'))
        model.add(keras.layers.LeakyReLU())
        model.add(keras.layers.Dropout(0.2))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1))
    return model
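  • 分别建立生成器和判别器的损失函数。判别器的损失包含两部分:对真实图像的判别结果应接近"真"(标签为 1),对生成图像的判别结果应接近"假"(标签为 0)。生成器的损失则是用标签 1 去衡量判别器对生成图像的输出,即希望判别器把生成图像判为真。注意:传入这些损失函数的应是判别器的输出(logit),而不是图像本身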
# loss function
def generator_loss(fake_iamge):
    """Generator loss: reward samples that the discriminator scores as real.

    Args:
        fake_iamge: the discriminator's logits for generated images. The name
            is historical; this should be the discriminator output, not the
            generated images themselves.

    Returns:
        Binary cross-entropy of those logits against an all-ones target.
    """
    target = tf.ones_like(fake_iamge)
    return cross_entropy(target, fake_iamge)

def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real samples should score 1, generated samples 0.

    Arguments are (real, fake), the order train_one_step passes them in.

    Args:
        real_output: discriminator logits for a batch of real images.
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        Sum of the real-vs-1 and fake-vs-0 binary cross-entropy terms.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Generated samples must be labelled 0. Labelling them 1 (as before)
    # rewards the discriminator for calling fakes real, so it never learns.
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss
  • 定义单个批次的训练过程:生成器和判别器各用一个 GradientTape 记录同一次前向计算,再分别求梯度并用各自的优化器更新参数
# training
@tf.function
def train_one_step(images):
    """Run one adversarial update of both networks on a batch of real images.

    Both networks share a single forward pass. Each has its own gradient tape
    and optimizer.

    Args:
        images: a batch of real images (batch dimension first).
    """
    # Size the noise batch from the real batch. The final batch of an epoch
    # can be smaller than BATCH_SIZE when the dataset size is not a multiple.
    batch_size = tf.shape(images)[0]
    z = tf.random.normal([batch_size, z_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = g(z, training=True)

        real_pred = d(images, training=True)
        fake_pred = d(fake_images, training=True)

        # The generator is judged by the discriminator's verdict on its
        # samples (fake_pred). Passing the raw pixels (fake_images) would
        # never send gradient through the discriminator.
        g_loss = generator_loss(fake_pred)
        d_loss = discriminator_loss(real_pred, fake_pred)

    g_gradients = g_tape.gradient(g_loss, g.trainable_variables)
    d_gradients = d_tape.gradient(d_loss, d.trainable_variables)

    g_optimizer.apply_gradients(zip(g_gradients, g.trainable_variables))
    d_optimizer.apply_gradients(zip(d_gradients, d.trainable_variables))
    
  • 对整个数据集的训练过程:每个 epoch 遍历所有批次,用固定的噪声 seed 生成并保存一组样本图像,并定期保存检查点
def train(dataset, epochs, save_every=15):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    After every epoch, a grid generated from the fixed ``seed`` noise is saved.
    A checkpoint is written every ``save_every`` epochs and always after the
    final epoch, so the end-of-training weights are kept.

    Args:
        dataset: iterable of real-image batches (e.g. a ``tf.data.Dataset``).
        epochs: number of full passes over ``dataset``.
        save_every: checkpoint interval in epochs (default 15, as before).
            A value <= 0 disables the periodic saves but keeps the final one.
    """
    for epoch in range(1, epochs + 1):
        start = time.time()
        for image_batch in dataset:
            train_one_step(image_batch)

        generate_and_save_images(g, epoch, seed)

        periodic = save_every > 0 and epoch % save_every == 0
        # Saving only on multiples of 15 would leave epochs 46-50 of a
        # 50-epoch run out of every checkpoint.
        if periodic or epoch == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)
        print('Time for epoch {} is {} sec'.format(epoch, time.time() - start))

    # Final sample grid. It reuses the last epoch's file name.
    generate_and_save_images(g, epochs, seed)

由于设备限制,这里没有给出运行结果,有 GPU 的读者可以自行运行。

完整的代码:

# -*- coding: utf-8 -*-
"""
Created on Sun Sep  8 16:17:50 2019

@author: dyliang
"""
from __future__ import absolute_import,print_function,division
import tensorflow as tf 
import tensorflow.keras as keras
import matplotlib.pyplot as plt 
import numpy as np 
import os
import PIL
import imageio
import glob
import time 

# ---- hyperparameters ----
BUFFER_SIZE = 60000               # shuffle-buffer size for the training dataset
BATCH_SIZE = 256                  # examples per training batch
EPOCHS = 50                       # full passes over the dataset
z_dim = 100                       # length of the generator's noise vector
num_examples_to_generate = 16     # samples in each saved progress grid

# Fixed noise, reused after every epoch so that the saved grids show the
# same latent points evolving over training.
seed = tf.random.normal([num_examples_to_generate, z_dim])

# ---- load MNIST (test split discarded) and preview one training digit ----
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()

# NOTE: with an interactive backend, show() blocks a script until the
# preview window is closed.
plt.imshow(train_images[0])
plt.show()

# Add a channel axis, then rescale pixels from [0, 255] to [-1, 1] to match
# the generator's tanh output range.
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
train_images = (train_images.astype('float32') - 127.5) / 127.5

train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# generator
def make_generator(latent_dim=100):
    """Build the generator: noise vector -> 28x28x1 image with values in [-1, 1].

    A Dense projection is reshaped to a 7x7x256 feature map and upsampled by
    three transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28).

    Args:
        latent_dim: length of the input noise vector. Defaults to 100, the
            value that used to be hard-coded; pass ``z_dim`` to keep the two
            in sync.

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    return keras.Sequential([
        # BatchNormalization learns its own offset, so layer biases are dropped.
        keras.layers.Dense(7 * 7 * 256, use_bias=False,
                           input_shape=(latent_dim,)),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        keras.layers.Reshape((7, 7, 256)),
        # stride 1 keeps the 7x7 grid; 128 channels
        keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                     padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # stride 2: 7x7 -> 14x14; 64 channels
        keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # stride 2: 14x14 -> 28x28 with one channel; tanh bounds the output
        # to [-1, 1], the range train_images are rescaled to.
        keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False,
                                     activation='tanh'),
    ])

# smoke test: draw one sample from the untrained generator and plot it
g = make_generator()
z = tf.random.normal(shape=[1, 100])
fake_image = g(z, training=False)  # inference mode: BatchNorm uses moving stats
plt.imshow(fake_image[0, ..., 0], cmap='gray')

def make_discriminator():
    """Build the discriminator: two strided conv blocks followed by one logit.

    The first Conv2D declares no input shape, so the weights are created on
    the model's first call. The final Dense(1) has no activation, i.e. the
    model outputs an unnormalized score (logit) per example.
    """
    model = keras.Sequential()
    for filters in (64, 128):
        # stride 2 halves height and width in each block
        model.add(keras.layers.Conv2D(filters, (5, 5), strides=(2, 2),
                                      padding='same'))
        model.add(keras.layers.LeakyReLU())
        model.add(keras.layers.Dropout(0.2))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1))
    return model

# Score the smoke-test sample. Because make_discriminator's first layer has
# no input shape, this first call is also what creates d's weights.
d = make_discriminator()
pred = d(fake_image)
print('pred score is: ', pred)

# ---- loss and optimizers ----
# The discriminator ends in a Dense(1) layer with no activation, so the loss
# is told it receives raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# One Adam instance per network so their optimizer state stays independent.
g_optimizer, d_optimizer = (keras.optimizers.Adam(learning_rate=1e-4),
                            keras.optimizers.Adam(learning_rate=1e-4))

# loss function
def generator_loss(fake_iamge):
    """Generator loss: reward samples that the discriminator scores as real.

    Args:
        fake_iamge: the discriminator's logits for generated images. The name
            is historical; this should be the discriminator output, not the
            generated images themselves.

    Returns:
        Binary cross-entropy of those logits against an all-ones target.
    """
    target = tf.ones_like(fake_iamge)
    return cross_entropy(target, fake_iamge)

def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real samples should score 1, generated samples 0.

    Arguments are (real, fake), the order train_one_step passes them in.

    Args:
        real_output: discriminator logits for a batch of real images.
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        Sum of the real-vs-1 and fake-vs-0 binary cross-entropy terms.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Generated samples must be labelled 0. Labelling them 1 (as before)
    # rewards the discriminator for calling fakes real, so it never learns.
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

# ---- checkpointing ----
# Tracks both networks and both optimizers so their state can be saved
# (this script only saves; it never restores).
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    g=g,
    d=d,
)

# training
@tf.function
def train_one_step(images):
    """Run one adversarial update of both networks on a batch of real images.

    Both networks share a single forward pass. Each has its own gradient tape
    and optimizer.

    Args:
        images: a batch of real images (batch dimension first).
    """
    # Size the noise batch from the real batch. With 60000 MNIST images and
    # BATCH_SIZE=256, the last batch of each epoch holds only 96 images.
    batch_size = tf.shape(images)[0]
    z = tf.random.normal([batch_size, z_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = g(z, training=True)

        real_pred = d(images, training=True)
        fake_pred = d(fake_images, training=True)

        # The generator is judged by the discriminator's verdict on its
        # samples (fake_pred). Passing the raw pixels (fake_images) would
        # never send gradient through the discriminator.
        g_loss = generator_loss(fake_pred)
        d_loss = discriminator_loss(real_pred, fake_pred)

    g_gradients = g_tape.gradient(g_loss, g.trainable_variables)
    d_gradients = d_tape.gradient(d_loss, d.trainable_variables)

    g_optimizer.apply_gradients(zip(g_gradients, g.trainable_variables))
    d_optimizer.apply_gradients(zip(d_gradients, d.trainable_variables))
    
    
def train(dataset, epochs, save_every=15):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    After every epoch, a grid generated from the fixed ``seed`` noise is saved.
    A checkpoint is written every ``save_every`` epochs and always after the
    final epoch, so the end-of-training weights are kept.

    Args:
        dataset: iterable of real-image batches (e.g. ``train_dataset``).
        epochs: number of full passes over ``dataset``.
        save_every: checkpoint interval in epochs (default 15, as before).
            A value <= 0 disables the periodic saves but keeps the final one.
    """
    for epoch in range(1, epochs + 1):
        start = time.time()
        for image_batch in dataset:
            train_one_step(image_batch)

        generate_and_save_images(g, epoch, seed)

        periodic = save_every > 0 and epoch % save_every == 0
        # Saving only on multiples of 15 would leave epochs 46-50 of a
        # 50-epoch run out of every checkpoint.
        if periodic or epoch == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)
        print('Time for epoch {} is {} sec'.format(epoch, time.time() - start))

    # Final sample grid. It reuses the last epoch's file name.
    generate_and_save_images(g, epochs, seed)

def generate_and_save_images(model, epoch, test_input):
    """Render ``model(test_input)`` as an image grid, save it, and show it.

    Args:
        model: the generator. It is called with ``training=False`` so all layers
            (e.g. BatchNormalization) run in inference mode.
        epoch: number embedded in the file name ``image_at_epoch_{epoch:04d}.png``.
        test_input: noise batch; each sample fills one grid cell.
    """
    import math

    predictions = model(test_input, training=False)
    n = predictions.shape[0]
    # Smallest near-square grid that holds n samples; 16 samples give the
    # original 4x4 layout. max() keeps an empty batch from dividing by zero.
    cols = max(1, int(math.ceil(math.sqrt(n))))
    rows = int(math.ceil(n / float(cols)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        # Undo the [-1, 1] scaling back to [0, 255] for display.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # One figure is created per call (i.e. every epoch). With non-interactive
    # backends show() does not close it, so release it explicitly.
    plt.close(fig)

            
if __name__ == '__main__':
    # Only training is guarded: data loading and the preview plots earlier in
    # the module also run when it is imported.
    train(dataset=train_dataset, epochs=EPOCHS)
    

你可能感兴趣的:(GAN)