Reproducing the Paper "Globally and Locally Consistent Image Completion"

The code below implements the paper "Globally and Locally Consistent Image Completion". Paper: http://xueshu.baidu.com/usercenter/paper/show?paperid=ea74830570062151f14abfb1fe89bb33&site=xueshu_se&hitarticle=1
For a quick read-through of the paper, see my other article: https://www.jianshu.com/p/12da271c8bf8
The dataset is the CelebA face dataset, available at: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
Framework: TensorFlow 1.11.0
Image size: 128 * 128 (the paper uses 256 * 256)
The authors trained on four K80 GPUs for two months. I currently have no comparable hardware and cannot realistically train this to completion, so although the code below runs, some hyperparameters were set from experience and their effect has not been verified. Use with caution.

Since I have already written up the paper's approach elsewhere, I won't walk through the implementation ideas again here; if anything is unclear, refer to the paper and the companion article linked above. The code is commented in as much detail as possible, so it should not be hard to follow.

make_data.py

"""
说明:数据预处理,将图片读取到npy文件中,这样就可避免每次都去读一个一个的图片数据,可以加快读取数据的速度
npy文件——Numpy专用的二进制格式
数据集地址:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
import glob
import cv2
import numpy as np


image_size = 128
train_ratio = 0.8

# Collect all image paths. Keep the dataset outside the project directory,
# or the IDE may get sluggish while indexing the project.
paths = glob.glob(r'D:\bigdata\img_align_celeba/*.jpg')
x = []  # list of image arrays
# Read the images; to keep training fast, only the first 1000 are used
for img_path in paths[:1000]:
    '''
    If cv2 has no autocompletion hints, uninstall and reinstall OpenCV:
        pip uninstall opencv-python
        pip --no-cache-dir install opencv-python -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
        pip --no-cache-dir install opencv-contrib-python -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
    '''
    img = cv2.imread(img_path)  # image as an array, shape: (218, 178, 3)
    # Resize the image (with interpolation)
    img = cv2.resize(img, (image_size, image_size))
    # Convert BGR -> RGB to simplify mask generation and other operations
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    x.append(img)

x = np.array(x, dtype=np.uint8)  # store as np.uint8 to save space
# Shuffle the images
np.random.shuffle(x)

p = int(train_ratio * len(x))
x_train = x[:p]
x_test = x[p:]

np.save(r'D:\demos\image processing\demo\data_my\x_train.npy', x_train)
np.save(r'D:\demos\image processing\demo\data_my\x_test.npy', x_test)
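
As a quick sanity check before moving on, here is a minimal sketch (my addition; it assumes the same paths as make_data.py above) that reloads the two .npy files and prints their shapes: with 1000 images and train_ratio = 0.8, x_train should come out as (800, 128, 128, 3) with dtype uint8.

# Sanity-check sketch: verify the arrays written by make_data.py (paths are assumed)
import numpy as np

x_train = np.load(r'D:\demos\image processing\demo\data_my\x_train.npy')
x_test = np.load(r'D:\demos\image processing\demo\data_my\x_test.npy')
print(x_train.shape, x_train.dtype)  # expected: (800, 128, 128, 3) uint8
print(x_test.shape)                  # expected: (200, 128, 128, 3)
print(x_train.min(), x_train.max())  # raw pixel values, within [0, 255]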

network_build.py

"""
定义训练网络所需的各部分
"""
import tensorflow as tf


class Network:
    def __init__(self, x, mask, local_x, global_completion, local_completion, is_training, batch_size):
        """
        :param x: 输入
        :param mask: 需要填补修复的图像表示,缺失区域的值为1,其他部分为0,整体大小和完整图像相同
        :param local_x: 从原图中抠出来的那部分
        :param global_completion: 经过补全网络后的缺失部分之外的部分图像
        :param local_completion: 经过补全网络之后的补全的部分
        :param is_training:
        :param batch_size:
        """
        self.batch_size = batch_size
        # x * (1 - mask) "digs the hole": pixels inside the hole become 0, so the
        # generator receives an image with a hole in it
        self.imitation = self.generator(x * (1 - mask), is_training)
        # The completed image uses the generated content inside the hole and the
        # original pixel values everywhere else
        self.completion = self.imitation * mask + x * (1 - mask)
        # For real images, the discriminator creates and updates its own variables
        self.real = self.discriminator(x, local_x, reuse=False)
        # For completed images, the discriminator must use the same variables created
        # during the real pass, hence reuse=True
        self.fake = self.discriminator(global_completion, local_completion, reuse=True)
        self.g_loss = self.calc_g_loss(x, self.completion)
        self.d_loss = self.calc_d_loss(self.real, self.fake)
        """
        tf.get_collection(key,scope=None)
        用来获取一个名称是‘key’的集合中的所有元素,返回的是一个列表,列表的顺序是按照变量放入集合中的先后;   scope参数可选,
        表示的是名称空间(名称域),如果指定,就返回名称域中所有放入‘key’的变量的列表,不指定则返回所有变量。
        tf.Optimizer默认只优化tf.GraphKeys.TRAINABLE_VARIABLES中的变量。
        """
        self.g_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.d_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    def conv_layer(self, x, filter_shape, stride):
        filters = tf.get_variable(
            name='weight',
            shape=filter_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        # padding='SAME': output size = input size / stride, rounded up (padded as needed)
        # padding='VALID': output size = (input size - filter size + 1) / stride, rounded up, no padding
        """
        tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)
        Besides `name`, which names the op, there are five relevant parameters:

        input: the images to convolve, a 4-D tensor of shape
        [batch, in_height, in_width, in_channels], i.e. [images per batch, image height,
        image width, image channels], of type float32 or float64.

        filter: the convolution kernel, a 4-D tensor of shape
        [filter_height, filter_width, in_channels, out_channels], i.e. [kernel height,
        kernel width, input channels, number of kernels], same type as input. Note that
        its third dimension, in_channels, is the fourth dimension of input.

        strides: the stride in each dimension of the image, a 1-D vector of length 4.

        padding: a string, either "SAME" or "VALID", selecting the convolution mode.

        use_cudnn_on_gpu: bool, whether to use cuDNN acceleration; defaults to True.

        Returns a tensor, the familiar feature map, again of shape
        [batch, height, width, channels].
        """
        return tf.nn.conv2d(x, filters, [1, stride, stride, 1], padding='SAME')

    # Dilated (atrous) convolution
    def dilated_conv_layer(self, x, filter_shape, dilation):
        filters = tf.get_variable(
            name='weight',
            shape=filter_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        """
        tf.nn.atrous_conv2d(value,filters,rate,padding,name=None)
        value: 指需要做卷积的输入图像,要求是一个4维Tensor,具有[batch, height, width, channels]
        filters: 相当于CNN中的卷积核,要求是一个4维Tensor,具有[filter_height, filter_width, channels, out_channels]
                  这样的shape,具体含义是[卷积核的高度,卷积核的宽度,图像通道数或前一次卷积核个数,本次卷积核个数]
        rate: 即空洞率dilation,在卷积核中穿插补(rate-1)个0,rate=1时,就没有0插入,此时这个函数就变成了普通卷积。
        """
        return tf.nn.atrous_conv2d(x, filters, dilation, padding='SAME')

    # Transposed convolution (deconvolution)
    def deconv_layer(self, x, filter_shape, output_shape, stride):
        filters = tf.get_variable(
            name='weight',
            shape=filter_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        """
        tf.conv2d_transpose(value, filter, output_shape, strides, padding="SAME", data_format="NHWC", name=None)
        第一个参数value:指需要做反卷积的输入图像,它要求是一个Tensor
        第二个参数filter:卷积核,它要求是一个Tensor,具有[filter_height, filter_width, out_channels, in_channels]这样的shape,
        具体含义是[卷积核的高度,卷积核的宽度,卷积核个数,图像通道数或上次卷积核个数]
        第三个参数output_shape:反卷积操作输出的shape,普通卷积操作是没有这个参数的.
        第四个参数strides:反卷积时在图像每一维的步长,这是一个一维的向量,长度4
        第五个参数padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式
        第六个参数data_format:string类型的量,'NHWC'和'NCHW'其中之一,这是tensorflow新版本中新加的参数,它说明了value参数的
        数据格式。'NHWC'指tensorflow标准的数据格式[batch, height, width, in_channels],'NCHW'指Theano的数据格式,
        [batch, in_channels,height, width],当然默认值是'NHWC'
        """
        return tf.nn.conv2d_transpose(x, filters, output_shape, [1, stride, stride, 1])

    def batch_normalize(self, x, is_training, decay=0.99, epsilon=0.001):
        """
        tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)是一个低级的操作函数,
        调用者需要自己处理张量的平均值和方差。
        mean:样本均值
        variance:样本方差
        offset:样本偏移,None或一个向量,添加到归一化中
        scale:缩放(默认为1),None或一个向量,添加到归一化中
        :param x:
        :param is_training:
        :param decay:
        :param epsilon:为了避免分母为0,添加的一个极小值
        :return:
        """

        def bn_train():
            # Mean and variance of the current batch
            batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2])
            # Update the running (population) statistics used at inference time
            train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean, train_var]):
                # The normalization below runs only after [train_mean, train_var] have executed
                return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)

        def bn_inference():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon)
        """
        tf.shape(x)返回的是一个tensor。要想知道是多少,必须通过sess.run()
        x.get_shape()返回的是元组,需要通过as_list()的操作转换成list,x必须是tensor
        x:[batch, height, width, channels or kernels],则dim就是channels的值,图像数据的第三维
        """
        dim = x.get_shape().as_list()[-1]
        beta = tf.get_variable(
            name='beta',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.0),
            trainable=True)
        scale = tf.get_variable(
            name='scale',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.1),
            trainable=True)
        pop_mean = tf.get_variable(
            name='pop_mean',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.0),
            trainable=False)
        pop_var = tf.get_variable(
            name='pop_var',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(1.0),
            trainable=False)
        # tf.cond() works like a ternary (condition ? a : b) expression
        return tf.cond(is_training, bn_train, bn_inference)
    
    def flatten_layer(self, x):
        """
        图像矩阵转换为一个向量,有batch_size个这种向量
        :param x:
        :return:
        """
        input_shape = x.get_shape().as_list()
        dim = input_shape[1] * input_shape[2] * input_shape[3]  # total number of values in one image
        # Reorder the dimensions (NHWC -> NCHW) before flattening
        transposed = tf.transpose(x, (0, 3, 1, 2))
        return tf.reshape(transposed, [-1, dim])

    def full_connection_layer(self, x, out_dim):
        # in_dim is the output size of the previous layer
        in_dim = x.get_shape().as_list()[-1]
        W = tf.get_variable(
            name='weight',
            shape=[in_dim, out_dim],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.1),
            trainable=True)
        b = tf.get_variable(
            name='bias',
            shape=[out_dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.0),
            trainable=True)
        return tf.add(tf.matmul(x, W), b)

    def generator(self, x, is_training):
        with tf.variable_scope('generator'):
            with tf.variable_scope('conv1'):
                x = self.conv_layer(x, [5, 5, 3, 64], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv2'):
                x = self.conv_layer(x, [3, 3, 64, 128], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv3'):
                x = self.conv_layer(x, [3, 3, 128, 128], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv4'):
                x = self.conv_layer(x, [3, 3, 128, 256], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv5'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv6'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated1'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated2'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 4)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated3'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 8)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated4'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 16)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv7'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv8'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('deconv1'):
                x = self.deconv_layer(x, [4, 4, 128, 256], [self.batch_size, 64, 64, 128], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv9'):
                x = self.conv_layer(x, [3, 3, 128, 128], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('deconv2'):
                x = self.deconv_layer(x, [4, 4, 64, 128], [self.batch_size, 128, 128, 64], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv10'):
                x = self.conv_layer(x, [3, 3, 64, 32], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv11'):
                x = self.conv_layer(x, [3, 3, 32, 3], 1)
                x = tf.nn.tanh(x)
        # output image size: 128 * 128
        return x

    def discriminator(self, x, local_x, reuse):
        def global_discriminator(x):
            is_training = tf.constant(True)
            with tf.variable_scope('global'):
                # Since we use image_size = 128 instead of the paper's 256, this branch
                # has one fewer conv layer than in the paper
                with tf.variable_scope('conv1'):
                    x = self.conv_layer(x, [5, 5, 3, 64], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv2'):
                    x = self.conv_layer(x, [5, 5, 64, 128], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv3'):
                    x = self.conv_layer(x, [5, 5, 128, 256], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv4'):
                    x = self.conv_layer(x, [5, 5, 256, 512], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv5'):
                    x = self.conv_layer(x, [5, 5, 512, 512], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('fc'):
                    x = self.flatten_layer(x)
                    x = self.full_connection_layer(x, 1024)
            return x

        def local_discriminator(x):
            is_training = tf.constant(True)
            with tf.variable_scope('local'):
                # The paper uses LOCAL_SIZE = 128; we use 64, so this branch also has
                # one fewer conv layer
                with tf.variable_scope('conv1'):
                    x = self.conv_layer(x, [5, 5, 3, 64], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv2'):
                    x = self.conv_layer(x, [5, 5, 64, 128], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv3'):
                    x = self.conv_layer(x, [5, 5, 128, 256], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv4'):
                    x = self.conv_layer(x, [5, 5, 256, 512], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('fc'):
                    x = self.flatten_layer(x)
                    # The fully connected layer maps the flattened conv features to a
                    # 1024-d vector, with no activation
                    x = self.full_connection_layer(x, 1024)
            return x

        with tf.variable_scope('discriminator', reuse=reuse):
            """
            reuse参数:
            True: 参数空间使用reuse 模式,即该空间下的所有tf.get_variable()函数将直接获取已经创建的变量,
            如果参数不存在tf.get_variable()函数将会报错。
            AUTO_REUSE:若参数空间的参数不存在就创建他们,如果已经存在就直接获取它们。
            None 或者False 这里创建函数tf.get_variable()函数只能创建新的变量,当同名变量已经存在时,函数就报错
            * reuse(重用)标志是有继承性的:如果我们打开一个重用范围,那么它的所有子范围也会重用。
            """
            global_output = global_discriminator(x)
            local_output = local_discriminator(local_x)
            with tf.variable_scope('concatenation'):
                output = tf.concat((global_output, local_output), 1)
                output = self.full_connection_layer(output, 1)

        return output

    def calc_g_loss(self, x, completion):
        # Completion-network loss: measures how far the completed image is from the
        # original (tf.nn.l2_loss(t) returns the scalar sum(t ** 2) / 2)
        loss = tf.nn.l2_loss(x - completion)
        return tf.reduce_mean(loss)

    def calc_d_loss(self, real, fake):
        # Discriminator loss: a binary classification problem
        alpha = 4e-4  # weighting factor from the paper, balancing d_loss against g_loss
        # tf.ones_like(real) creates an all-ones tensor with the same shape as real
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real, labels=tf.ones_like(real)))
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake, labels=tf.zeros_like(fake)))
        return tf.add(d_loss_real, d_loss_fake) * alpha
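
The heart of Network.__init__ above is the compositing step completion = imitation * mask + x * (1 - mask): generated pixels are used only inside the hole, and original pixels are kept everywhere else. The toy NumPy sketch below (my addition, with made-up values) shows the effect:

# Toy sketch of the hole-digging and compositing done in Network.__init__
import numpy as np

x = np.ones((1, 4, 4, 1))               # stand-in "original image": all ones
imitation = np.full((1, 4, 4, 1), 0.5)  # stand-in "generator output": all 0.5
mask = np.zeros((1, 4, 4, 1))
mask[0, 1:3, 1:3, 0] = 1                # a 2 * 2 hole in the middle

holed = x * (1 - mask)                  # the generator's input: hole pixels zeroed out
completion = imitation * mask + x * (1 - mask)
print(completion[0, :, :, 0])           # 0.5 inside the 2 * 2 hole, 1.0 everywhere else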

train.py

import numpy as np
import tensorflow as tf
import os
import cv2
import tqdm
from network_build import Network

"""
如果运行过程中出现:
An error ocurred while starting the kernel
2019 20:27:22.601831: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this 
TensorFlow binary was not compiled to use: AVX2
大概意思是:你的CPU支持AVX扩展,但是你安装的TensorFlow版本无法编译使用。

那为什么会出现这种警告呢?
由于tensorflow默认分布是在没有CPU扩展的情况下构建的,例如SSE4.1,SSE4.2,AVX,AVX2,FMA等。默认版本(来自pip install 
tensorflow的版本)旨在与尽可能多的CPU兼容。另一个观点是,即使使用这些扩展名,CPU的速度也要比GPU慢很多,并且期望在GPU上执行中型和大型机器学习培训。

如果你有一个GPU,你不应该关心AVX的支持,因为大多数昂贵的操作将被分派到一个GPU设备上(除非明确地设置)。在这种情况下,您可以简单地忽略此警告:
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

BATCH_SIZE = 10
IMAGE_SIZE = 128
LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
LEARNING_RATE = 1e-3
PRETRAIN_EPOCH = 100


def load(dir_=r'D:\demos\image processing\demo\data_my'):
    x_train = np.load(os.path.join(dir_, 'x_train.npy'))
    x_test = np.load(os.path.join(dir_, 'x_test.npy'))
    return x_train, x_test


def get_points():
    points = []
    mask = []
    for i in range(BATCH_SIZE):
        x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
        x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
        points.append([x1, y1, x2, y2])

        w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
        p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
        q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
        p2 = p1 + w
        q2 = q1 + h

        m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
        m[q1:q2 + 1, p1:p2 + 1] = 1
        mask.append(m)
    # Each entry of points describes a LOCAL_SIZE * LOCAL_SIZE region. Each mask is
    # IMAGE_SIZE * IMAGE_SIZE and contains a (q2-q1) * (p2-p1) block of ones (zeros
    # everywhere else), and that block lies inside the corresponding points region.
    return np.array(points), np.array(mask)
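
# Hypothetical example of one draw from get_points() (illustrative values only):
#   points[i] = [12, 30, 76, 94]  -> one LOCAL_SIZE * LOCAL_SIZE window (x1, y1, x2, y2)
#   mask[i]   -> an IMAGE_SIZE * IMAGE_SIZE * 1 array of zeros with a rectangular block
#                of ones (each side between HOLE_MIN and HOLE_MAX pixels) inside that window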


def train():
    """
    tf.reset_default_graph函数用于清除默认图形堆栈并重置全局默认图形。

    注意:默认图形是当前线程的一个属性。该tf.reset_default_graph函数只适用于当前线程。当一个tf.Session或者tf.InteractiveSession
    激活时调用这个函数会导致未定义的行为。调用此函数后使用任何以前创建的tf.Operation或tf.Tensor对象将导致未定义的行为。
    可能引发的异常:
    AssertionError:如果在嵌套图中调用此函数则会引发此异常。
    Clears the default graph stack and resets the global default graph.
    """
    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="x")
    mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1], name="mask")
    local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_x")
    global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="global_completion")
    local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_completion")
    is_training = tf.placeholder(tf.bool, [], name="is_training")

    model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
    # global_step is used by moving averages, optimizers, exponential learning-rate
    # decay, and so on. Its meaning is simple: it counts global training steps, like a
    # clock ("do this at step N", "the network is at step M"). It starts at 0, and
    # passing global_step to the optimizer's minimize() increments it by 1 per batch.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    epoch = tf.Variable(0, name='epoch', trainable=False)

    opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    # tf.train.Optimizer.minimize adds ops that minimize the loss and update var_list.
    # It simply combines compute_gradients() and apply_gradients(), returning the
    # update op; if global_step is not None, it is also incremented by one.
    g_train_op = opt.minimize(model.g_loss, global_step=global_step, var_list=model.g_variables)
    d_train_op = opt.minimize(model.d_loss, global_step=global_step, var_list=model.d_variables)

    # Load the data
    x_train, x_test = load()
    # Normalize every pixel to [-1, 1]
    x_train = np.array([a / 127.5 - 1 for a in x_train])
    x_test = np.array([a / 127.5 - 1 for a in x_test])

    # number of steps per epoch
    step_num = int(len(x_train) / BATCH_SIZE)

    
    init_op = tf.global_variables_initializer()
    # cv2.imwrite and Saver.save do not create missing directories, so make them up front
    for d in ('./output1', './output2', './backup'):
        os.makedirs(d, exist_ok=True)
    with tf.Session() as sess:
        sess.run(init_op)

        # Restore a previously trained model, if one exists, to speed up training
        # (restore from ./backup, the same directory the script saves to below)
        if tf.train.get_checkpoint_state('./backup'):
            saver = tf.train.Saver()
            saver.restore(sess, './backup/latest')

        while True:
            # epoch += 1 on every pass through the dataset
            sess.run(tf.assign(epoch, tf.add(epoch, 1)))
            print('epoch: {}'.format(sess.run(epoch)))

            # Reshuffle the training data once per epoch
            np.random.shuffle(x_train)

            # Completion
            # First train only the image-completion network, for PRETRAIN_EPOCH = 100 epochs
            # Note: reading the value of a tensor variable requires a run()
            if sess.run(epoch) <= PRETRAIN_EPOCH:
                g_loss_value = 0
                points_batch, mask_batch = get_points()
                # tqdm is a Python progress-bar library; wrap any long loop's iterator
                # with tqdm(iterator) to get a progress display
                for i in tqdm.tqdm(range(step_num)):
                    # each epoch runs step_num steps, each consuming one BATCH_SIZE
                    # slice of the training set
                    x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

                    _, g_loss = sess.run([g_train_op, model.g_loss],
                                         feed_dict={x: x_batch, mask: mask_batch, is_training: True})
                    g_loss_value += g_loss
                print("epoch:{}".format(sess.run(epoch)))
                print("Completion loss: {}".format(g_loss_value))

                np.random.shuffle(x_test)
                # Masks are generated BATCH_SIZE at a time, so take BATCH_SIZE test images as well
                x_batch = x_test[:BATCH_SIZE]
                completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
                print("completion[0].shape:", completion[0].shape)
                # De-normalize the image back to [0, 255]
                sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
                cv2.imwrite('./output1/{}.jpg'.format("{0:06d}".format(sess.run(epoch))),
                            cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))

                saver = tf.train.Saver()
                saver.save(sess, './backup/latest')
                if sess.run(epoch) == PRETRAIN_EPOCH:
                    saver.save(sess, './backup/pretrained')

            # Discrimination
            # After PRETRAIN_EPOCH epochs, train the generator and discriminator together
            else:
                g_loss_value = 0
                d_loss_value = 0
                points_batch, mask_batch = get_points()
                for i in tqdm.tqdm(range(step_num)):
                    x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

                    # Train the generator
                    _, g_loss, completion = sess.run([g_train_op, model.g_loss, model.completion],
                                                     feed_dict={x: x_batch, mask: mask_batch, is_training: True})
                    g_loss_value += g_loss

                    local_x_batch = []
                    local_completion_batch = []
                    # Crop the local regions from both the original images and the
                    # completed images for this batch (j avoids shadowing the step index i)
                    for j in range(BATCH_SIZE):
                        x1, y1, x2, y2 = points_batch[j]
                        local_x_batch.append(x_batch[j][y1:y2, x1:x2, :])
                        local_completion_batch.append(completion[j][y1:y2, x1:x2, :])
                    local_x_batch = np.array(local_x_batch)
                    local_completion_batch = np.array(local_completion_batch)

                    """
                    d_train_op用到了d_loss,d_loss来自于calc_d_loss,calc_d_loss有real和fake两个参数,
                    real来自于discriminator(x, local_x, reuse=False)
                    fake来自于discriminator(global_completion, local_completion, reuse=True)
                    所以feed_dict的参数包括x、local_x、global_completion、local_completion,以及mask
                    
                    """
                    _, d_loss = sess.run(
                        [d_train_op, model.d_loss],
                        feed_dict={x: x_batch, mask: mask_batch, local_x: local_x_batch, global_completion: completion,
                                   local_completion: local_completion_batch, is_training: True})
                    d_loss_value += d_loss

                print("epoch:{}".format(sess.run(epoch)))
                print('Completion loss: {}'.format(g_loss_value))
                print('Discriminator loss: {}'.format(d_loss_value))

                np.random.shuffle(x_test)
                x_batch = x_test[:BATCH_SIZE]
                completion = sess.run(model.completion,
                                      feed_dict={x: x_batch, mask: mask_batch, is_training: False})
                sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
                cv2.imwrite('./output2/{}.jpg'.format("{0:06d}".format(sess.run(epoch))),
                            cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))

                saver = tf.train.Saver()
                saver.save(sess, './backup/latest', write_meta_graph=False)


if __name__ == '__main__':
    train()
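
Finally, once ./backup/latest exists, the completion network can be run on its own. The sketch below is my addition, not part of the original scripts: it assumes train.py and network_build.py are importable from the same directory and that a checkpoint has been saved; it rebuilds the same graph, restores the weights, and writes one completed test image.

inference.py (hypothetical file name)

# Hedged inference sketch; assumes a './backup/latest' checkpoint produced by train.py
import numpy as np
import tensorflow as tf
import cv2
from network_build import Network
from train import load, get_points, BATCH_SIZE, IMAGE_SIZE, LOCAL_SIZE

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1])
local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
is_training = tf.placeholder(tf.bool, [])
model = Network(x, mask, local_x, global_completion, local_completion, is_training, BATCH_SIZE)

_, x_test = load()
x_test = np.array([a / 127.5 - 1 for a in x_test])  # same [-1, 1] normalization as training
_, mask_batch = get_points()                        # random holes, same generator as training

with tf.Session() as sess:
    tf.train.Saver().restore(sess, './backup/latest')
    completion = sess.run(model.completion,
                          feed_dict={x: x_test[:BATCH_SIZE], mask: mask_batch, is_training: False})
    sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)  # back to [0, 255]
    cv2.imwrite('inference_sample.jpg', cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))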
