# pix2pix (TensorFlow 1.x eager) — personal implementation trained on the facades dataset

import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import PIL
from IPython.display import clear_output
# Download and extract the facades dataset archive into the current directory.
# `path_to_zip` is the local path of the downloaded tarball.
path_to_zip = tf.keras.utils.get_file("facades.tar.gz", cache_subdir = os.path.abspath('.'),
                                     origin="https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz",
                                     extract=True)
# Root of the extracted dataset; expected to contain train/ and test/ folders of jpegs.
PATH = os.path.join(os.path.dirname(path_to_zip),'facades/')
BUFFER_SIZE = 400   # shuffle-buffer size (presumably the train-set size — confirm)
BATCH_SIZE = 1      # NOTE(review): defined but the pipelines below hard-code .batch(1)
IMG_WIDTH = 256
IMG_HEIGHT = 256

def load_image(image_file, is_train):
    """Load one facades jpeg, split it into (input, real) halves, optionally
    augment (train only), and normalize pixel values to [-1, 1].

    Args:
        image_file: scalar string tensor, path to a jpeg whose left half is the
            real photo and whose right half is the label/input image.
        is_train: Python bool; when True, apply random jitter (resize to
            286x286, random crop back to 256x256) and a random horizontal flip.

    Returns:
        (input_image, real_image) float32 tensors in [-1, 1].
    """
    image = tf.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    # Each file holds two images side by side; split at the horizontal midpoint.
    w = tf.shape(image)[1]
    w = w // 2

    real_image = image[:, :w, :]
    input_image = image[:, w:, :]

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    if is_train:
        # Random jitter: upscale to 286x286, then random-crop a 256x256 window.
        input_image = tf.image.resize_images(input_image, [286, 286],
                                             align_corners=True,
                                             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        real_image = tf.image.resize_images(real_image, [286, 286],
                                            align_corners=True,
                                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        # Stack so one random_crop applies the identical offset to both images.
        stacked_image = tf.stack([input_image, real_image], axis=0)
        cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
        input_image, real_image = cropped_image[0], cropped_image[1]

        # BUG FIX: the original used `if np.random.random() > 0.5`. Dataset.map
        # traces this function into a graph once, so the NumPy call ran a single
        # time and every example got the same flip decision. Use a graph-level
        # random op so the coin is flipped per example.
        def _flip_both():
            return (tf.image.flip_left_right(input_image),
                    tf.image.flip_left_right(real_image))

        input_image, real_image = tf.cond(tf.random_uniform([]) > 0.5,
                                          _flip_both,
                                          lambda: (input_image, real_image))
    else:
        # Evaluation: deterministic resize only (method=2 is BICUBIC).
        input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH],
                                             align_corners=True, method=2)
        real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH],
                                            align_corners=True, method=2)

    # Map [0, 255] -> [-1, 1], matching the generator's tanh output range.
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image
        
# Training pipeline: list jpegs, shuffle filenames, decode + augment via load_image.
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(lambda x:load_image(x, True))
train_dataset = train_dataset.batch(1)
# Test pipeline: no shuffle, no augmentation (is_train=False).
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
test_dataset = test_dataset.map(lambda x:load_image(x, False))
test_dataset = test_dataset.batch(1)
# Number of channels in the generator's output (RGB).
OUTPUT_CHANNELS = 3
class Downsample(tf.keras.Model):
    """Encoder block: strided Conv2D -> optional BatchNorm -> LeakyReLU.

    Each application halves the spatial resolution (stride 2, 'same' padding).
    """

    def __init__(self, filters, size, apply_batchnorm = True):
        super(Downsample, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        init = tf.random_normal_initializer(0., 0.02)

        # Bias is disabled because BatchNorm (when present) absorbs the shift.
        self.conv1 = tf.keras.layers.Conv2D(
            filters, (size, size), strides=2, padding='same',
            kernel_initializer=init, use_bias=False)
        if self.apply_batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()

    def call(self, x, training):
        out = self.conv1(x)
        if self.apply_batchnorm:
            out = self.batchnorm(out, training=training)
        return tf.nn.leaky_relu(out)

class Upsample(tf.keras.Model):
    """Decoder block: Conv2DTranspose -> BatchNorm -> optional Dropout -> ReLU,
    then concatenation with the matching encoder feature map (U-Net skip).
    """

    def __init__(self, filters, size, apply_dropout = False):
        super(Upsample, self).__init__()
        self.apply_dropout = apply_dropout
        init = tf.random_normal_initializer(0., 0.02)

        # Stride-2 transposed conv doubles the spatial resolution.
        self.up_conv = tf.keras.layers.Conv2DTranspose(
            filters, (size, size), strides=2, padding='same',
            kernel_initializer=init, use_bias=False)
        self.batchnorm = tf.keras.layers.BatchNormalization()
        if self.apply_dropout:
            self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, x1, x2, training):
        out = self.up_conv(x1)
        out = self.batchnorm(out, training=training)
        if self.apply_dropout:
            out = self.dropout(out, training=training)
        out = tf.nn.relu(out)
        # Skip connection: append the encoder features along the channel axis.
        return tf.concat([out, x2], axis=-1)

class Generator(tf.keras.Model):
    """U-Net generator: 8 downsampling blocks, 7 upsampling blocks with skip
    connections, and a final transposed conv producing a tanh RGB image.
    """

    def __init__(self):
        super(Generator, self).__init__()
        init = tf.random_normal_initializer(0., 0.02)

        # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 (spatial).
        self.down1 = Downsample(64, 4, apply_batchnorm = False)
        self.down2 = Downsample(128, 4)
        self.down3 = Downsample(256, 4)
        self.down4 = Downsample(512, 4)
        self.down5 = Downsample(512, 4)
        self.down6 = Downsample(512, 4)
        self.down7 = Downsample(512, 4)
        self.down8 = Downsample(512, 4)

        # Decoder: dropout on the three deepest upsampling blocks.
        self.up1 = Upsample(512, 4, apply_dropout = True)
        self.up2 = Upsample(512, 4, apply_dropout = True)
        self.up3 = Upsample(512, 4, apply_dropout = True)
        self.up4 = Upsample(512, 4)
        self.up5 = Upsample(256, 4)
        self.up6 = Upsample(128, 4)
        self.up7 = Upsample(64, 4)

        # Final layer maps back to OUTPUT_CHANNELS at full resolution.
        self.last = tf.keras.layers.Conv2DTranspose(
            OUTPUT_CHANNELS, (4, 4), strides=2, padding='same',
            kernel_initializer=init)

    @tf.contrib.eager.defun
    def call(self, x, training):
        # Encoder pass; keep every activation for the skip connections below.
        d1 = self.down1(x, training=training)
        d2 = self.down2(d1, training=training)
        d3 = self.down3(d2, training=training)
        d4 = self.down4(d3, training=training)
        d5 = self.down5(d4, training=training)
        d6 = self.down6(d5, training=training)
        d7 = self.down7(d6, training=training)
        d8 = self.down8(d7, training=training)

        # Decoder pass; each block concatenates the mirrored encoder activation.
        u1 = self.up1(d8, d7, training=training)
        u2 = self.up2(u1, d6, training=training)
        u3 = self.up3(u2, d5, training=training)
        u4 = self.up4(u3, d4, training=training)
        u5 = self.up5(u4, d3, training=training)
        u6 = self.up6(u5, d2, training=training)
        u7 = self.up7(u6, d1, training=training)

        # tanh keeps outputs in [-1, 1], matching the normalized targets.
        return tf.nn.tanh(self.last(u7))
class DiscDownsample(tf.keras.Model):
    """Discriminator encoder block: strided Conv2D -> optional BatchNorm ->
    LeakyReLU; halves the spatial resolution per application.
    """

    def __init__(self, filters, size, apply_batchnorm = True):
        super(DiscDownsample, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        init = tf.random_normal_initializer(0., 0.02)

        # No bias: BatchNorm (when enabled) provides the learned shift.
        self.conv1 = tf.keras.layers.Conv2D(
            filters, (size, size), strides=2, padding='same',
            kernel_initializer=init, use_bias=False)
        if self.apply_batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()

    def call(self, x, training):
        out = self.conv1(x)
        if self.apply_batchnorm:
            out = self.batchnorm(out, training=training)
        return tf.nn.leaky_relu(out)

class Discriminator(tf.keras.Model):
    """Conditional PatchGAN-style discriminator.

    Takes the input image and a target (real or generated) image, concatenated
    channel-wise, and outputs a spatial map of per-patch real/fake logits
    (no sigmoid — the loss functions operate on logits).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        initializer = tf.random_normal_initializer(0., 0.02)

        # Three strided downsampling blocks; no batchnorm on the first.
        self.down1 = DiscDownsample(64, 4, False)
        self.down2 = DiscDownsample(128, 4)
        self.down3 = DiscDownsample(256, 4)

        self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
        self.conv = tf.keras.layers.Conv2D(512,
                                          (4, 4),
                                          strides = 1,
                                          kernel_initializer = initializer,
                                          use_bias = False)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()

        self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
        # Final 1-channel logit map.
        self.last = tf.keras.layers.Conv2D(1,
                                          (4, 4),
                                          strides = 1,
                                          kernel_initializer = initializer)

    @tf.contrib.eager.defun
    def call(self, inp, tar, training):
        # BUG FIX: the original concatenated on axis=1 (the height axis for
        # NHWC tensors), stacking the two images vertically. The discriminator
        # must see them as one 6-channel conditional input, i.e. axis=-1.
        x = tf.concat([inp, tar], axis = -1)
        x = self.down1(x, training = training)
        x = self.down2(x, training = training)
        x = self.down3(x, training = training)

        x = self.zero_pad1(x)
        x = self.conv(x)
        x = self.batchnorm1(x, training = training)
        x = tf.nn.leaky_relu(x)

        x = self.zero_pad2(x)

        x = self.last(x)

        return x
# Instantiate the two networks (weights are created lazily on first call).
generator = Generator()
discriminator = Discriminator()
# Weight of the L1 reconstruction term relative to the GAN term in the generator loss.
LAMBDA = 100
def discriminator_loss(disc_real_output, disc_generated_output):
    """Sigmoid cross-entropy on the discriminator's logit maps.

    Real patches are pushed toward label 1, generated patches toward label 0;
    the two terms are summed.
    """
    real_loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(disc_real_output),
        logits=disc_real_output)
    generated_loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.zeros_like(disc_generated_output),
        logits=disc_generated_output)
    return real_loss + generated_loss
def generator_loss(disc_generated_output, gen_output, target):
    """Generator objective: fool the discriminator plus LAMBDA-weighted L1.

    The GAN term pushes generated patches toward label 1; the L1 term pulls
    the generated image toward the ground-truth target.
    """
    gan_loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(disc_generated_output),
        logits=disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return gan_loss + LAMBDA * l1_loss
# Adam with lr=2e-4, beta1=0.5 for both networks.
generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1 = 0.5)
discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1 = 0.5)
# Object-based checkpointing of both models and both optimizer states.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(generator_optimizer = generator_optimizer,
                                discriminator_optimizer = discriminator_optimizer,
                                generator = generator,
                                discriminator = discriminator)
EPOCHS = 200
def generated_images(model, test_input, tar):
    """Show the input, ground truth, and model prediction side by side.

    NOTE(review): the model is called with training=True here — presumably so
    BatchNorm/Dropout behave as during training when visualizing; confirm this
    is intentional before changing it.
    """
    predictions = model(test_input, training = True)
    plt.figure(figsize=(15, 15))

    images = [test_input[0], tar[0], predictions[0]]
    titles = ['Input Image', 'Ground Truth', 'Predicted Image']

    for idx, (img, label) in enumerate(zip(images, titles)):
        plt.subplot(1, 3, idx + 1)
        plt.title(label)
        # Rescale from [-1, 1] back to [0, 1] for imshow.
        plt.imshow(img * 0.5 + 0.5)
        plt.axis('off')

    plt.show()
def train(dataset, epochs):
    """Run the pix2pix training loop over `dataset` for `epochs` epochs.

    Per batch: one forward pass of generator and discriminator (on both real
    and generated targets), then one gradient step for each network from its
    own tape. Visualizes a test sample every epoch and checkpoints every 20.
    Uses the module-level generator/discriminator/optimizers/test_dataset.
    """
    for epoch in range(epochs):
        start = time.time()

        for input_image, target in dataset:

            # Two tapes: each network's gradients are computed independently.
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                gen_output = generator(input_image, training = True)

                # Discriminator sees (input, real) and (input, generated) pairs.
                disc_real_output = discriminator(input_image, target, training = True)
                disc_generated_output = discriminator(input_image, gen_output, training = True)

                gen_loss = generator_loss(disc_generated_output, gen_output, target)
                disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

            generator_gradients = gen_tape.gradient(gen_loss,
                                                   generator.variables)
            discriminator_gradients = disc_tape.gradient(disc_loss,
                                                        discriminator.variables)

            generator_optimizer.apply_gradients(zip(generator_gradients, generator.variables))
            discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.variables))

        # `epoch % 1 == 0` is always true: show progress on a test sample every epoch.
        if epoch %1 == 0:
            clear_output(wait=True)
            for inp, tar in test_dataset.take(1):
                generated_images(generator, inp, tar)

        # Save a checkpoint every 20 epochs.
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print("Time taken for epoch {} is {} sec\n".format(epoch + 1,
                                                          time.time() -start))
# Kick off training (downloads/data pipelines must have been set up above).
train(train_dataset, EPOCHS)
# (Residue from the original notebook/blog export — kept as comments:)
# output_18_0.png
# Time taken for epoch 200 is 56.112149238586426 sec
# Restore the most recent checkpoint, then visualize predictions on the whole test set.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for inp, tar in test_dataset:
    generated_images(generator, inp, tar)

# You may also be interested in: (pix2pix tensorflow personal)