SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

This blog post is based on the original paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" (SRGAN, CVPR 2017) and Chapter 6 of 《生成对抗网络入门指南》 (An Introduction to Generative Adversarial Networks). The complete code, with brief commentary, is walked through in Section V.


I. Abstract: Why Use SRGAN

Deeper and faster CNNs have already pushed super-resolution (SR) quality a long way, but how do we recover fine detail when upsampling an image at large scale factors? This paper applies a GAN to the image super-resolution problem.

This is the first framework capable of photo-realistic super-resolution of natural images at a 4× upscaling factor. To achieve it, the authors redesign the loss function and use a deep ResNet as the generator backbone. The key contributions:

  1. An adversarial loss: a discriminator is trained to distinguish original HR images from super-resolved ones, pushing the generated images toward the manifold of natural images.
  2. A content loss motivated by perceptual similarity rather than similarity in pixel space.
  3. A deep ResNet that can recover photo-realistic textures from heavily downsampled images.
  4. A mean-opinion-score (MOS) test as the measure of image quality; the final results show that the MOS of images produced by SRGAN is closer to that of the original high-resolution images than the MOS obtained by any other state-of-the-art method.

 

II. Super-Resolution (SR) Research

Super-resolution (SR) refers to techniques that generate a high-resolution (HR) image from a low-resolution (LR) input.

Most widely used supervised SR algorithms optimize a pixel-wise objective and lose high-frequency texture detail, so the generated images look blurry. These methods typically minimize the mean squared error (MSE), which at the same time maximizes the peak signal-to-noise ratio (PSNR).

However, MSE and PSNR values correlate poorly with perceptual quality: the image with the highest PSNR is not necessarily the best SR result visually.
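To make this concrete: PSNR is a monotone function of MSE, so minimizing one maximizes the other. A minimal sketch for 8-bit images (NumPy, assuming pixel values in [0, 255]):

import numpy as np

def psnr(hr, sr, max_val=255.0):
    # PSNR = 10 * log10(MAX^2 / MSE): maximized exactly when MSE is minimized
    mse = np.mean((hr.astype(np.float64) - sr.astype(np.float64)) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)

This is why models trained on MSE dominate PSNR benchmarks while still producing over-smoothed results.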
 

[Figure 1: high PSNR does not imply better perceptual quality]

This paper proposes SRGAN, with a deep ResNet as the generator network being optimized. Unlike previous work, it defines a new perceptual loss built on the high-level feature maps of a VGG network, combined with a discriminator that judges the super-resolved images. Below is an example of 4× upscaling:

[Figure 2: 4× upscaled SR example]

 

III. SRGAN Architecture

1. Goal: train a generating function G that, for a given low-resolution (LR) input image, estimates its high-resolution (HR) counterpart.

 

2. Architecture

① Generator: the generator is a feed-forward CNN. For the training data, an SR-specific loss l^{SR} is minimized with respect to the generator parameters \hat{\theta}_G:

\hat{\theta}_G = \arg\min_{\theta_G} \frac{1}{N} \sum_{n=1}^{N} l^{SR}(G_{\theta_G}(I_n^{LR}), I_n^{HR})

Here I^{HR} is a high-resolution training image, I^{LR} is its low-resolution (downsampled) version, \hat{\theta}_G are the generator parameters, and l^{SR} is the loss function defined in the objective-function section below.

 

Within this feed-forward network, a ResNet architecture processes the input LR image.

[Figure 3: SRGAN generator network architecture (ResNet with skip connections)]

 

 

② Discriminator: following the original GAN formulation, we solve the minimax problem

\min_{\theta_G} \max_{\theta_D} \mathbb{E}_{I^{HR} \sim p_{train}(I^{HR})}[\log D_{\theta_D}(I^{HR})] + \mathbb{E}_{I^{LR} \sim p_G(I^{LR})}[\log(1 - D_{\theta_D}(G_{\theta_G}(I^{LR})))]

Here I^{HR} is a high-resolution training image and I^{LR} is its low-resolution (downsampled) version.

 

The discriminator is trained to distinguish real HR images from generated SR samples. It uses LeakyReLU activations and avoids max-pooling throughout. (The pre-trained VGG19 network mentioned below is used only for the content loss, not inside the discriminator.)

[Figure 4: SRGAN discriminator network architecture]

③ Objective function: l^{SR} is the perceptual loss function used to assess the quality of the generated images. In the paper it is the weighted sum of a content loss l_X^{SR} and an adversarial loss:

l^{SR} = l_X^{SR} + 10^{-3} l_{Gen}^{SR}

  • Content loss

Pixel-wise MSE loss:

l_{MSE}^{SR} = \frac{1}{r^2 W H} \sum_{x=1}^{rW} \sum_{y=1}^{rH} (I_{x,y}^{HR} - G_{\theta_G}(I^{LR})_{x,y})^2

where r is the upscaling factor (here 4) and W, H are the width and height of I^{LR}. This is the most widely used optimization target in state-of-the-art SR work. However, optimizing MSE tends to discard high-frequency content, so the results are overly smooth and perceptually unsatisfying on textured regions.

Instead, we define a VGG loss on the ReLU activations of a pre-trained 19-layer VGG network:

l_{VGG/i,j}^{SR} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} (\phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G_{\theta_G}(I^{LR}))_{x,y})^2

Here I^{HR} is a high-resolution training image, I^{LR} is its low-resolution (downsampled) version, W_{i,j} and H_{i,j} are the dimensions of the respective feature map inside the VGG network, and \phi_{i,j} denotes the feature map obtained by the j-th convolution (after activation) and before the i-th max-pooling layer of VGG19.
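Conceptually the VGG content loss is just MSE in feature space. A minimal TF 1.x sketch, where phi stands for any function returning the chosen VGG19 feature map (a hypothetical helper, not the paper's code):

import tensorflow as tf

def vgg_content_loss(phi, hr, sr):
    # MSE between feature maps phi(I^HR) and phi(G(I^LR)),
    # averaged over the W_{i,j} x H_{i,j} feature dimensions
    return tf.reduce_mean(tf.squared_difference(phi(hr), phi(sr)))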

 

  • Adversarial loss (GAN loss)

This is the standard generative loss based on the discriminator's output for the generated images:

l_{Gen}^{SR} = \sum_{n=1}^{N} -\log D_{\theta_D}(G_{\theta_G}(I^{LR}))
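The implementation later in this post uses the numerically stabler non-saturating cross-entropy form of this loss instead of the raw -log D term. A small TF 1.x sketch of both variants (d_fake_logits, the discriminator's pre-sigmoid output, is an assumed placeholder):

import tensorflow as tf

d_fake_logits = tf.placeholder(tf.float32, [None, 1])  # D's logits for G's outputs

# paper's form: -log D(G(I^LR)), averaged over the batch
g_adv_paper = tf.reduce_mean(-tf.log(tf.nn.sigmoid(d_fake_logits) + 1e-8))

# equivalent non-saturating form: sigmoid cross-entropy against "real" labels
g_adv_sce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))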

 

Many later papers adopt this loss structure; in GAN-based art and image generation in particular, the content loss is extremely common.

 

IV. Experimental Evaluation

MOS testing: 26 human raters scored the super-resolved images on a scale from 1 (bad quality) to 5 (excellent quality). On BSD100, SRGAN's MOS is significantly higher than that of all reference methods and closest to the MOS of the original HR images.

[Figure 5: MOS test results]

 

V. Experiment Code

Dataset: https://data.vision.ee.ethz.ch/cvl/DIV2K/
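DIV2K contains 800 high-resolution training images plus 100 validation and 100 test images; Track 1 pairs each HR image with a bicubically ×4-downscaled LR version, which is the track used below.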

1. Imports and hyperparameter initialization

import tensorflow as tf

import vgg19

import sys

sys.path.append('../')
import tfutil as t

tf.set_random_seed(777)  # reproducibility


class SRGAN:

    def __init__(self, s, batch_size=16, height=384, width=384, channel=3,
                 sample_num=1 * 1, sample_size=1,
                 df_dim=64, gf_dim=64, lr=1e-4, use_vgg19=True):

        """ Super-Resolution GAN Class
        # General Settings
        :param s: TF Session
        :param batch_size: training batch size, default 16
        :param height: input image height, default 384
        :param width: input image width, default 384
        :param channel: input image channel, default 3 (RGB)
        - in case of DIV2K-HR, image size is 384x384x3(HWC).

        # Output Settings
        :param sample_num: the number of output images, default 1
        :param sample_size: sample image size, default 1

        # For CNN model
        :param df_dim: discriminator filter, default 64
        :param gf_dim: generator filter, default 64

        # Training Option
        :param lr: learning rate, default 1e-4
        :param use_vgg19: using pre-trained vgg19 bottle-neck features, default True
        """

        self.s = s
        self.batch_size = batch_size

        self.height = height
        self.width = width
        self.channel = channel

        self.lr_image_shape = [None, self.height // 4, self.width // 4, self.channel]
        self.hr_image_shape = [None, self.height, self.width, self.channel]

        self.vgg_image_shape = [224, 224, 3]

        self.sample_num = sample_num
        self.sample_size = sample_size

        self.df_dim = df_dim
        self.gf_dim = gf_dim

        self.beta1 = 0.9
        self.beta2 = 0.999

        self.lr_decay_rate = 1e-1
        self.lr_low_boundary = 1e-5
        self.lr_update_step = 1e5
        self.lr_update_epoch = 1000

        self.vgg_mean = [103.939, 116.779, 123.68]

        # pre-defined
        self.d_real = 0.
        self.d_fake = 0.
        self.d_loss = 0.
        self.g_adv_loss = 0.
        self.g_cnt_loss = 0.
        self.g_loss = 0.
        self.psnr = 0.

        self.use_vgg19 = use_vgg19
        self.vgg19 = None

        self.g = None

        self.adv_scaling = 1e-3
        self.cnt_scaling = 1. / 12.75  # ≈ 7.8e-2, content-loss scaling

        self.d_op = None
        self.g_op = None
        self.g_init_op = None

        self.merged = None
        self.writer = None
        self.saver = None

        # Placeholders
        self.x_hr = tf.placeholder(tf.float32, shape=self.hr_image_shape, name="x-image-hr")  # (-1, 384, 384, 3)
        self.x_lr = tf.placeholder(tf.float32, shape=self.lr_image_shape, name="x-image-lr")  # (-1, 96, 96, 3)

        self.lr = tf.placeholder(tf.float32, name='lr')  # learning-rate placeholder (fed at run time; supersedes the ctor's lr argument)

        self.build_srgan()  # build SRGAN model

 

2. Building the generator and discriminator

① Discriminator: uses LeakyReLU activations and strided convolutions instead of max-pooling.

    def discriminator(self, x, reuse=None):
        """
        # Following a network architecture referred in the paper
        :param x: Input images (-1, 384, 384, 3)
        :param reuse: re-usability
        :return: logits indicating real (HR) vs. generated (SR) input (sigmoid is applied inside the loss)
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            x = t.conv2d(x, self.df_dim, 3, 1, name='n64s1-1')
            x = tf.nn.leaky_relu(x)

            strides = [2, 1]
            filters = [1, 2, 2, 4, 4, 8, 8]  # channel multipliers: 64 -> 128 -> 256 -> 512

            # strided convs replace max-pooling, matching the paper's n64s2 ... n512s2 blocks
            for i, f in enumerate(filters):
                x = t.conv2d(x, f=f * self.df_dim, k=3, s=strides[i % 2],
                             name='n%ds%d-%d' % (f * self.df_dim, strides[i % 2], i + 1))
                x = t.batch_norm(x, name='n%d-bn-%d' % (f * self.df_dim, i + 1))
                x = tf.nn.leaky_relu(x)

            x = tf.layers.flatten(x)  # (-1, 24 * 24 * 512) for a 384x384 input

            x = t.dense(x, 1024, name='disc-fc-1')
            x = tf.nn.leaky_relu(x)

            x = t.dense(x, 1, name='disc-fc-2')
            # x = tf.nn.sigmoid(x)
            return x
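
Here t.conv2d, t.batch_norm, and t.dense are thin wrappers from the author's tfutil module. t.sce_loss (used in build_srgan below) is presumably sigmoid cross-entropy on logits, which is why the final sigmoid stays commented out and the discriminator returns raw scores.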

② Generator

    def generator(self, x, reuse=None, is_train=True):
        """
        :param x: LR (Low Resolution) images, (-1, 96, 96, 3)
        :param reuse: scope re-usability
        :param is_train: is trainable, default True
        :return: SR (Super Resolution) images, (-1, 384, 384, 3)
        """

        with tf.variable_scope("generator", reuse=reuse):
            def residual_block(x, f, name="", _is_train=True):
                with tf.variable_scope(name):
                    shortcut = tf.identity(x, name='n64s1-shortcut')

                    x = t.conv2d(x, f, 3, 1, name="n64s1-1")
                    x = t.batch_norm(x, is_train=_is_train, name="n64s1-bn-1")
                    x = t.prelu(x, reuse=reuse, name='n64s1-prelu-1')
                    x = t.conv2d(x, f, 3, 1, name="n64s1-2")
                    x = t.batch_norm(x, is_train=_is_train, name="n64s1-bn-2")
                    x = tf.add(x, shortcut)

                    return x

            x = t.conv2d(x, self.gf_dim, 9, 1, name='n64s1-1')
            x = t.prelu(x, name='n64s1-prelu-1')

            skip_conn = tf.identity(x, name='skip_connection')

            # B residual blocks
            for i in range(1, 17):  # B = 16 residual blocks
                x = residual_block(x, self.gf_dim, name='b-residual_block_%d' % i, _is_train=is_train)

            x = t.conv2d(x, self.gf_dim, 3, 1, name='n64s1-3')
            x = t.batch_norm(x, is_train=is_train, name='n64s1-bn-3')

            x = tf.add(x, skip_conn)

            # sub-pixel conv2d blocks
            for i in range(1, 3):
                x = t.conv2d(x, self.gf_dim * 4, 3, 1, name='n256s1-%d' % (i + 2))
                x = t.sub_pixel_conv2d(x, f=None, s=2)
                x = t.prelu(x, name='n256s1-prelu-%d' % i)

            x = t.conv2d(x, self.channel, 9, 1, name='n3s1')  # (-1, 384, 384, 3)
            x = tf.nn.tanh(x)
            return x
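
t.sub_pixel_conv2d implements the sub-pixel (PixelShuffle) upscaling used in the paper: each preceding conv produces 4× the channels, which are then rearranged into a 2× larger spatial grid, and two such blocks give the overall 4× upscaling (96 → 192 → 384). A minimal TF 1.x sketch of the rearrangement (assuming the tfutil helper behaves like this):

import tensorflow as tf

def sub_pixel_conv2d(x, s=2):
    # (N, H, W, C * s^2) -> (N, H * s, W * s, C): channels become spatial positions
    return tf.depth_to_space(x, block_size=s)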

 

3. Building the VGG network

    def build_vgg19(self, x, reuse=None):
        with tf.variable_scope("vgg19", reuse=reuse):
            # image re-scaling
            x = tf.cast((x + 1) / 2, dtype=tf.float32)  # [-1, 1] to [0, 1]
            x = tf.cast(x * 255., dtype=tf.float32)     # [0, 1]  to [0, 255]

            r, g, b = tf.split(x, 3, 3)
            bgr = tf.concat([b - self.vgg_mean[0],
                             g - self.vgg_mean[1],
                             r - self.vgg_mean[2]], axis=3)

            self.vgg19 = vgg19.VGG19(bgr)

            net = self.vgg19.vgg19_net['conv5_4']

            return net  # last layer
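
The RGB→BGR swap and per-channel mean subtraction reproduce the preprocessing of the original Caffe-trained VGG19 weights, which most TensorFlow ports of vgg19 expect. conv5_4 is the deepest convolutional feature map, corresponding to the high-level \phi_{5,4} features discussed in the objective-function section.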

 

4. Building the SRGAN model

    def build_srgan(self):
        # Generator
        self.g = self.generator(self.x_lr)

        # Discriminator
        d_real = self.discriminator(self.x_hr)
        d_fake = self.discriminator(self.g, reuse=True)

        # Losses
        # d_real_loss = -tf.reduce_mean(t.safe_log(d_real))
        # d_fake_loss = -tf.reduce_mean(t.safe_log(1. - d_fake))
        d_real_loss = t.sce_loss(d_real, tf.ones_like(d_real))
        d_fake_loss = t.sce_loss(d_fake, tf.zeros_like(d_fake))
        self.d_loss = d_real_loss + d_fake_loss

        if self.use_vgg19:
            x_vgg_real = tf.image.resize_images(self.x_hr, size=self.vgg_image_shape[:2], align_corners=False)
            x_vgg_fake = tf.image.resize_images(self.g, size=self.vgg_image_shape[:2], align_corners=False)

            vgg_bottle_real = self.build_vgg19(x_vgg_real)
            vgg_bottle_fake = self.build_vgg19(x_vgg_fake, reuse=True)

            self.g_cnt_loss = self.cnt_scaling * t.mse_loss(vgg_bottle_fake, vgg_bottle_real, self.batch_size,
                                                            is_mean=True)
        else:
            self.g_cnt_loss = t.mse_loss(self.g, self.x_hr, self.batch_size, is_mean=True)

        # self.g_adv_loss = self.adv_scaling * tf.reduce_mean(-1. * t.safe_log(d_fake))
        self.g_adv_loss = self.adv_scaling * t.sce_loss(d_fake, tf.ones_like(d_fake))
        self.g_loss = self.g_adv_loss + self.g_cnt_loss

        def inverse_transform(img):
            return (img + 1.) * 127.5

        # calculate PSNR
        g, x_hr = inverse_transform(self.g), inverse_transform(self.x_hr)
        self.psnr = t.psnr_loss(g, x_hr, self.batch_size)

        # Summary
        tf.summary.scalar("loss/d_real_loss", d_real_loss)
        tf.summary.scalar("loss/d_fake_loss", d_fake_loss)
        tf.summary.scalar("loss/d_loss", self.d_loss)
        tf.summary.scalar("loss/g_cnt_loss", self.g_cnt_loss)
        tf.summary.scalar("loss/g_adv_loss", self.g_adv_loss)
        tf.summary.scalar("loss/g_loss", self.g_loss)
        tf.summary.scalar("misc/psnr", self.psnr)
        tf.summary.scalar("misc/lr", self.lr)

        # Optimizer
        t_vars = tf.trainable_variables()
        d_params = [v for v in t_vars if v.name.startswith('d')]
        g_params = [v for v in t_vars if v.name.startswith('g')]

        self.d_op = tf.train.AdamOptimizer(learning_rate=self.lr,
                                           beta1=self.beta1, beta2=self.beta2).minimize(loss=self.d_loss,
                                                                                        var_list=d_params)
        self.g_op = tf.train.AdamOptimizer(learning_rate=self.lr,
                                           beta1=self.beta1, beta2=self.beta2).minimize(loss=self.g_loss,
                                                                                        var_list=g_params)

        # pre-train
        self.g_init_op = tf.train.AdamOptimizer(learning_rate=self.lr,
                                                beta1=self.beta1, beta2=self.beta2).minimize(loss=self.g_cnt_loss,
                                                                                             var_list=g_params)

        # Merge summary
        self.merged = tf.summary.merge_all()

        # Model saver
        self.saver = tf.train.Saver(max_to_keep=2)
        self.writer = tf.summary.FileWriter('./model/', self.s.graph)
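
Note the three optimizers: d_op and g_op perform the adversarial updates, while g_init_op minimizes the content loss alone. This mirrors the paper's recipe of pre-training an MSE-based SRResNet and using it to initialize the generator, so adversarial training starts from a good local optimum.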

 

5. Main training script

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import numpy as np

import sys
import time

sys.path.append('../')
import image_utils as iu
from datasets import Div2KDataSet as DataSet


np.random.seed(1337)


results = {
    'output': './gen_img/',
    'model': './model/SRGAN-model.ckpt'
}

train_step = {
    'batch_size': 16,
    'init_epochs': 100,
    'train_epochs': 1501,
    'global_step': 200001,
    'logging_interval': 100,
}
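
init_epochs controls the generator pre-training phase: for the first 100 epochs only g_init_op (content loss only) runs, and the adversarial d_op/g_op updates begin afterwards, as the training loop below shows.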


def main():
    start_time = time.time()  # Clocking start

    # Div2K - Track 1: Bicubic downscaling - x4 DataSet load
    """
    ds = DataSet(ds_path="/home/zero/hdd/DataSet/DIV2K/",
                 ds_name="X4",
                 use_save=True,
                 save_type="to_h5",
                 save_file_name="/home/zero/hdd/DataSet/DIV2K/DIV2K",
                 use_img_scale=True)
    """
    ds = DataSet(ds_hr_path="/home/zero/hdd/DataSet/DIV2K/DIV2K-hr.h5",
                 ds_lr_path="/home/zero/hdd/DataSet/DIV2K/DIV2K-lr.h5",
                 use_img_scale=True)

    hr, lr = ds.hr_images, ds.lr_images

    print("[+] Loaded HR image ", hr.shape)
    print("[+] Loaded LR image ", lr.shape)

    # GPU configure
    gpu_config = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_config)

    with tf.Session(config=config) as s:
        with tf.device("/gpu:1"):  # change to your GPU index
            # SRGAN Model
            model = SRGAN(s, batch_size=train_step['batch_size'],
                                use_vgg19=False)

        # Initializing
        s.run(tf.global_variables_initializer())

        # Load model & Graph & Weights
        ckpt = tf.train.get_checkpoint_state('./model/')
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            model.saver.restore(s, ckpt.model_checkpoint_path)

            global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("[+] global step : %d" % global_step, " successfully loaded")
        else:
            global_step = 0
            print('[-] No checkpoint file found')

        start_epoch = global_step // (ds.n_images // train_step['batch_size'])

        rnd = np.random.randint(0, ds.n_images)
        sample_x_hr, sample_x_lr = hr[rnd], lr[rnd]

        sample_x_hr, sample_x_lr = \
            np.reshape(sample_x_hr, [1] + model.hr_image_shape[1:]), \
            np.reshape(sample_x_lr, [1] + model.lr_image_shape[1:])

        # Export real image
        # valid_image_height = model.sample_size
        # valid_image_width = model.sample_size
        sample_hr_dir, sample_lr_dir = results['output'] + 'valid_hr.png', results['output'] + 'valid_lr.png'

        # Generated image save
        iu.save_images(sample_x_hr,
                       size=[1, 1],
                       image_path=sample_hr_dir,
                       inv_type='127')

        iu.save_images(sample_x_lr,
                       size=[1, 1],
                       image_path=sample_lr_dir,
                       inv_type='127')

        learning_rate = 1e-4
        for epoch in range(start_epoch, train_step['train_epochs']):
            pointer = 0
            for i in range(ds.n_images // train_step['batch_size']):
                start = pointer
                pointer += train_step['batch_size']

                if pointer > ds.n_images:  # ran past the dataset end -> reshuffle and restart
                    # Shuffle training DataSet
                    perm = np.arange(ds.n_images)
                    np.random.shuffle(perm)

                    hr, lr = hr[perm], lr[perm]

                    start = 0
                    pointer = train_step['batch_size']

                end = pointer

                batch_x_hr, batch_x_lr = hr[start:end], lr[start:end]

                # reshape
                batch_x_hr = np.reshape(batch_x_hr, [train_step['batch_size']] + model.hr_image_shape[1:])
                batch_x_lr = np.reshape(batch_x_lr, [train_step['batch_size']] + model.lr_image_shape[1:])

                # Update Only G network
                d_loss, g_loss, g_init_loss = 0., 0., 0.
                if epoch <= train_step['init_epochs']:
                    _, g_init_loss = s.run([model.g_init_op, model.g_cnt_loss],
                                           feed_dict={
                                               model.x_hr: batch_x_hr,
                                               model.x_lr: batch_x_lr,
                                               model.lr: learning_rate,
                                           })
                # Update G/D network
                else:
                    _, d_loss = s.run([model.d_op, model.d_loss],
                                      feed_dict={
                                          model.x_hr: batch_x_hr,
                                          model.x_lr: batch_x_lr,
                                          model.lr: learning_rate,
                                      })

                    _, g_loss = s.run([model.g_op, model.g_loss],
                                      feed_dict={
                                          model.x_hr: batch_x_hr,
                                          model.x_lr: batch_x_lr,
                                          model.lr: learning_rate,
                                      })

                if i % train_step['logging_interval'] == 0:
                    # Print loss
                    if epoch <= train_step['init_epochs']:
                        print("[+] Epoch %04d Step %08d => " % (epoch, global_step),
                              " MSE loss : {:.8f}".format(g_init_loss))
                    else:
                        print("[+] Epoch %04d Step %08d => " % (epoch, global_step),
                              " D loss : {:.8f}".format(d_loss),
                              " G loss : {:.8f}".format(g_loss))

                        summary = s.run(model.merged,
                                        feed_dict={
                                            model.x_hr: batch_x_hr,
                                            model.x_lr: batch_x_lr,
                                            model.lr: learning_rate,
                                        })

                        # Summary saver
                        model.writer.add_summary(summary, global_step)

                    # Run G on the fixed sample LR image to monitor progress
                    sample_x_lr = np.reshape(sample_x_lr, [model.sample_num] + model.lr_image_shape[1:])
                    samples = s.run(model.g,
                                    feed_dict={
                                        model.x_lr: sample_x_lr,
                                        model.lr: learning_rate,
                                    })

                    # Export image generated by model G
                    # sample_image_height = model.output_height
                    # sample_image_width = model.output_width
                    sample_dir = results['output'] + 'train_{:08d}.png'.format(global_step)

                    # Generated image save
                    iu.save_images(samples,
                                   size=[1, 1],
                                   image_path=sample_dir,
                                   inv_type='127')

                    # Model save
                    model.saver.save(s, results['model'], global_step)

                # Learning Rate update
                if epoch and epoch % model.lr_update_epoch == 0:
                    learning_rate *= model.lr_decay_rate
                    learning_rate = max(learning_rate, model.lr_low_boundary)

                global_step += 1

    end_time = time.time() - start_time  # Clocking end

    # Elapsed time
    print("[+] Elapsed time {:.8f}s".format(end_time))

    # Close tf.Session
    s.close()


if __name__ == '__main__':
    main()

 

6. Results (generated images)

Input LR image:

[Figure: the input LR image]

Images during training, steps 0–55000:

[Figures: samples generated at training steps 0–55000]

Generated HR image:

[Figure: the generated HR image]

 


 
