SRGAN_tensorflow_code

SRGAN_tensorflow包含两个网络,分别是SRGAN网络和VGG网络。其中,SRGAN网络由生成器网络generator和判决器网络discriminator组成。根据原始论文(Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network)框架,其训练过程如下;

  1. 在SRResnet task下训练SRResnet网络,得到最初的generator网络。输入是HR降采样后的低分LR图片,输出是低分LR经过超分后产生得超分图片SR,HR作为ground truth和SR一起衡量content loss。循环100万次;
  2. 保留1中generator网络所有参数,随机初始化discriminator网络参数,训练SRGAN网络。使用MSE衡量生成器内容损失content loss,使用交叉验证衡量生成器对抗损失adversarial loss(由生成图片输入判决器判决结果决定),使用交叉验证衡量判决器的discriminator loss(包含discrim_fake_loss和discrim_real_loss,前者由生成图片经过判决器的输出结果决定,后者由真实图片经过判决器的输出结果决定)。循环50万次;
  3. 保留2中SRGAN网络所有参数,引入VGG网络。使用VGG网络分别提取真实图片HR和生成图片SR的特征,计算两个特征差作为损失函数,更好的保留图像细节信息。循环20万次

本文尽可能以通俗易懂的方式介绍下超分辨TensorFlow实现的代码。从main脚本文件开始,首先要需要输入命令行参数作为整个srgan的超参数。条件语句用来检测某些关键参数如输入输出路径是否传入或存在。在这些准备好之后,程序就要开始它的核心工作了,这里核心工作共有三个,test mode(利用训练好的generator模型在LR上测试),inference mode(利用训练好的generator在任意数据集上测试),以及train mode(训练模式包含两个task,SRResnet和SRGAN ,每个task下的perceptual_mode 有三种模式可选,VGG54,VGG22,MSE)。在train mode模式下,上述过程1工作时,task=SRResnet,perceptual_mode=MSE;过程2工作时,task=SRGAN,perceptual_mode=MSE;过程3工作时,task=SRGAN,perceptual_mode=VGG54/22。每个模式的详细功能分别介绍如下。

1.test mode。

任务是把低分辨数据集放到训练好的模型上测试并保存结果。进入该模式第一步就是调用mode.py中的test_data_loader(FLAGS)函数读入数据,os.listdir() 方法用于返回路径FLAGS.input_dir_LR下包含的所有文件的名字的列表。将文件名列表分别和各自对应的具体路径拼接在一起,得到两个包含文件名路径的完整地址列表image_list_LR和image_list_HR。接着,利用for循环调用preprocess_test(name, mode)函数读入image_list_LR和image_list_HR对应的图像列表,得到image_LR(map(0,1))和image_HR(map(-1,1))。最终返回一个大的列表test_data,包含两个完整路径和两个完整图像列表。同时定义四个占位符,后面使用时直接给占位符赋值。

    inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
    targets_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='targets_raw')
    path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
    path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')

接下来根据工作模式将数据送入generator,这里的inputs_raw占位符在session中会被赋以真实值test_data.inputs

    with tf.variable_scope('generator'):
        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
        else:
            raise NotImplementedError('Unknown task!!')

generator是实现超分网络的模型,主要由多个残差块和卷积层组成,详细信息如下:

# Definition of the generator
def generator(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
    # Check the flag
    if FLAGS is None:
        raise  ValueError('No FLAGS is provided for generator')

    # The Bx residual blocks
    def residual_block(inputs, output_channel, stride, scope):
        with tf.variable_scope(scope):
            #3x3kernel
            net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
            net = batchnorm(net, FLAGS.is_training)
            net = prelu_tf(net)
            #3x3kernel
            net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2')
            net = batchnorm(net, FLAGS.is_training)
            #skip connection
            net = net + inputs

        return net


    with tf.variable_scope('generator_unit', reuse=reuse):
        # The input layer
        with tf.variable_scope('input_stage'):
            # kernel size 9x9 kernel count(feature map) 64 stride 1
            net = conv2(gen_inputs, 9, 64, 1, scope='conv')
            net = prelu_tf(net)

        stage1_output = net

        # The residual block parts
        for i in range(1, FLAGS.num_resblock+1 , 1):
            name_scope = 'resblock_%d'%(i)
            net = residual_block(net, 64, 1, name_scope)

        with tf.variable_scope('resblock_output'):
            net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
            net = batchnorm(net, FLAGS.is_training)

        net = net + stage1_output

        with tf.variable_scope('subpixelconv_stage1'):
            net = conv2(net, 3, 256, 1, scope='conv')
            net = pixelShuffler(net, scale=2)
            net = prelu_tf(net)

        with tf.variable_scope('subpixelconv_stage2'):
            net = conv2(net, 3, 256, 1, scope='conv')
            net = pixelShuffler(net, scale=2)
            net = prelu_tf(net)

        with tf.variable_scope('output_stage'):
            net = conv2(net, 9, gen_output_channels, 1, scope='conv')

    return net

直接读入的数据先要经过conver image进行map处理,处理之后然后转为uint8格式得到converted_inputs,converted_targets,converted_outputs。

    with tf.name_scope('convert_image'):
        # Deprocess the images outputed from the model
        inputs = deprocessLR(inputs_raw)#不做处理return tf.identity(image)
        targets = deprocess(targets_raw)# #map处理 [-1, 1] => [0, 1](return (image + 1) / 2)
        outputs = deprocess(gen_output)

        # Convert back to uint8
        converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)

构造fetch,包含信息如下,在session.run中使用

    with tf.name_scope('encode_image'):
        save_fetch = {
            "path_LR": path_LR,
            "path_HR": path_HR,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name='input_pngs'),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name='output_pngs'),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name='target_pngs')
        }
    # Define the weight initiallizer (In test, we only need to restore the weight of the generator)
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    weight_initiallizer = tf.train.Saver(var_list)

    # Define the initialization operation
    init_op = tf.global_variables_initializer()

指定session的配置信息,开启session。每次循环从输入数据(LR)中拿出一张图送入训练好的生成器网络,将save_fetch的信息保存下来作为测试结果。

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Load the pretrained model
        print('Loading weights from the pre-trained model')
        weight_initiallizer.restore(sess, FLAGS.checkpoint)

        max_iter = len(test_data.inputs)
        print('Evaluation starts!!')
        for i in range(max_iter):
            input_im = np.array([test_data.inputs[i]]).astype(np.float32)
            target_im = np.array([test_data.targets[i]]).astype(np.float32)
            path_lr = test_data.paths_LR[i]
            path_hr = test_data.paths_HR[i]
            results = sess.run(save_fetch, feed_dict={inputs_raw: input_im, targets_raw: target_im,
                                                      path_LR: path_lr, path_HR: path_hr})
            filesets = save_images(results, FLAGS)
            for i, f in enumerate(filesets):
                print('evaluate image', f['name'])

2.inference mode 

任务是利用train好的网络对任意输入图像(可以使用自己的图像)进行超分处理并保存结果。

inference_data = inference_data_loader(FLAGS)#调用model.py中的inference_data_loader函数载入数据。

其中inference_data_loader函数的详细定义如下。对输入数据简单处理后返回一个列表,包含完整的图像地址和对应的图像。

# The inference data loader. Allow input image with different size
def inference_data_loader(FLAGS):
    # Get the image name list
    if (FLAGS.input_dir_LR == 'None'):
        raise ValueError('Input directory is not provided')

    if not os.path.exists(FLAGS.input_dir_LR):
        raise ValueError('Input directory not found')

    image_list_LR_temp = os.listdir(FLAGS.input_dir_LR)
    image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']

    # Read in and preprocess the images
    def preprocess_test(name):
        im = sic.imread(name, mode="RGB").astype(np.float32)
        # check grayscale image
        if im.shape[-1] != 3:
            h, w = im.shape
            temp = np.empty((h, w, 3), dtype=np.uint8)
            temp[:, :, :] = im[:, :, np.newaxis]
            im = temp.copy()
        im = im / np.max(im)

        return im

    image_LR = [preprocess_test(_) for _ in image_list_LR]

    # Push path and image into a list
    Data = collections.namedtuple('Data', 'paths_LR, inputs')


    return Data(
        paths_LR=image_list_LR,
        inputs=image_LR
    )

后面的过程和test模式下的操作完全类似。只是test模式保存的信息包括了真实图片,可以对比超分前后和真实图片的差别。推理模式只保存超分前后的图片,和真实图的区别无从比较。

3.train mode

该模式根据输入的真实高分图片和降采样后的图片分别训练discriminator和generator,得到完整的SRGAN网络。

 data = data_loader(FLAGS)

data_loader()函数具体信息如下:图像载入阶段完成图像文件名队列操作,张量转换;图像加载后预处理,解码,类型调整,数据增强(裁剪、缩放,翻转、扭曲),使输入网络训练信息多样化,缓解过拟合。训练阶段,使用tf.train.shuffle_batch:将队列output[0], output[1](output是数据初次载入值)中未做任何处理的数据打乱后,读取出来得到paths_LR_batch, paths_HR_batch,将增强后的数据input_images, target_images打乱后读取出来得到inputs_batch, targets_batch。根据输入图片数和batch数计算每个epoch中包含的step,执行一个batch算作一步。最后返回一个总的列表,包含四组数据batch,        #总的图像的个数image_count=len(image_list_LR),和一个epoch内batch的个数 steps_per_epoch=steps_per_epoch。

# Define the dataloader
def data_loader(FLAGS):
    with tf.device('/cpu:0'):
        # Define the returned data batches
        Data = collections.namedtuple('Data', 'paths_LR, paths_HR, inputs, targets, image_count, steps_per_epoch')

        #Check the input directory
        if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'):
            raise ValueError('Input directory is not provided')

        if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)):
            raise ValueError('Input directory not found')
        #os.listdir(path) 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中。循环中用_代替仅获取值而已
        image_list_LR = os.listdir(FLAGS.input_dir_LR)
        #所有.png组成的列表
        image_list_LR = [_ for _ in image_list_LR if _.endswith('.png')]
        #通过列表长度判空进而抛出异常
        if len(image_list_LR)==0:
            raise Exception('No png files in the input directory')
        #List的成员函数sort对给定的List 进行排序
        image_list_LR_temp = sorted(image_list_LR)
        #文件路径拼接,路径+文件名组合成完整路径,LR按顺序,HR不按顺序
        image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp]
        image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp]
        #用于将不同数据变成张量:比如可以让数组变成张量、也可以让列表变成张量。
        # dtype: Optional element type for the returned tensor. If missing, the(可以指定转化成tensor后输出的数据类型) type is inferred from the type of `value`.

        image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
        image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)

        #当reuse为False或者None时(这也是默认值),同一个tf.variable_scope下面的变量名不能相同;
        with tf.variable_scope('load_image'):
            # define the image list queue
            # image_list_LR_queue = tf.train.string_input_producer(image_list_LR, shuffle=False, capacity=FLAGS.name_queue_capacity)
            # image_list_HR_queue = tf.train.string_input_producer(image_list_HR, shuffle=False, capacity=FLAGS.name_queue_capacity)
            #print('[Queue] image list queue use shuffle: %s'%(FLAGS.mode == 'Train'))
            
            #tf.train.slice_input_producer:tensor生成器,这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。
            #参数shuffle: bool类型,设置是否打乱样本的顺序。如果shuffle=True,生成的样本顺序就被打乱了,
            #如果shuffle=False,样本顺序未被打乱,需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本。参数capacity:设置tensor列表的容量。
         
            output = tf.train.slice_input_producer([image_list_LR_tensor, image_list_HR_tensor],
                                                   shuffle=False, capacity=FLAGS.name_queue_capacity)

            # Reading and decode the images
            ## reader从文件名队列output中读数据。对应的方法是reader.read
            reader = tf.WholeFileReader(name='image_reader')
            #提取LR图片内容和HR图片内容,一定注意数据之间的转化;
            image_LR = tf.read_file(output[0])
            image_HR = tf.read_file(output[1])
            #channels必须要制定
            input_image_LR = tf.image.decode_png(image_LR, channels=3)
            input_image_HR = tf.image.decode_png(image_HR, channels=3)
            #图片数据进行转化,此处为了显示而转化
            input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32)
            input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32)
            #如果tf.shape(input_image_LR)[2]和3不一致就抛出异常 
            assertion = tf.assert_equal(tf.shape(input_image_LR)[2], 3, message="image does not have 3 channels")
            #tf.identity返回一个和输入的 tensor 大小和数值都一样的 tensor ,类似于 y=x 操作,主要的用途就是更好的控制在不同设备间传递变量的值
            with tf.control_dependencies([assertion]):
                input_image_LR = tf.identity(input_image_LR)
                input_image_HR = tf.identity(input_image_HR)

            # Normalize the low resolution image to [0, 1], high resolution to [-1, 1]
            a_image = preprocessLR(input_image_LR)
            b_image = preprocess(input_image_HR)

            inputs, targets = [a_image, b_image]

        # The data augmentation part
        with tf.name_scope('data_preprocessing'):
            with tf.name_scope('random_crop'):
                # Check whether perform crop
                if (FLAGS.random_crop is True) and FLAGS.mode == 'train':
                    print('[Config] Use random crop')
                    # Set the shape of the input image. the target will have 4X size
                    input_size = tf.shape(inputs)
                    target_size = tf.shape(targets)
                    #用于改变某个张量的数据类型 图像增强后的高宽
                    offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
                                       dtype=tf.int32)
                    offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
                                       dtype=tf.int32)

                    if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
                        #tf.image.crop_to_bounding_box:将图像裁剪到指定的边界框(offset_h, offset_w)代表新图像左上角坐标FLAGS.crop_size x FLAGS.crop_size:结果的高度和高度
                        inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size,
                                                               FLAGS.crop_size)
                        targets = tf.image.crop_to_bounding_box(targets, offset_h*4, offset_w*4, FLAGS.crop_size*4,
                                                                FLAGS.crop_size*4)
                    elif FLAGS.task == 'denoise':
                        inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size,
                                                               FLAGS.crop_size)
                        targets = tf.image.crop_to_bounding_box(targets, offset_h, offset_w,
                                                                FLAGS.crop_size, FLAGS.crop_size)
                # Do not perform crop裁剪
                else:
                    inputs = tf.identity(inputs)
                    targets = tf.identity(targets)

            with tf.variable_scope('random_flip'):
                # Check for random flip翻转:
                if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
                    print('[Config] Use random flip')
                    # Produce the decision of random flip产生均匀分布随机数
                    decision = tf.random_uniform([], 0, 1, dtype=tf.float32)

                    input_images = random_flip(inputs, decision)
                    target_images = random_flip(targets, decision)
                else:
                    input_images = tf.identity(inputs)
                    target_images = tf.identity(targets)

            if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
                input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
                target_images.set_shape([FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
            elif FLAGS.task == 'denoise':
                input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
                target_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
        # tf.train.shuffle_batch:将队列output[0], output[1]中数据打乱后,再读取出来,因此队列中剩下的数据也是乱序的,队头也是一直在补充
        # 读取一个文件并且加载一个张量中的batch_size行 capacity是队列的长度
        # tensor_list:入队的张量列表 batch_size:表示进行一次批处理的tensors数量 capacity:一个整数,队列中的最大的元素数
        # min_after_dequeue:当一次出列操作完成后,队列中元素的最小数量,往往用于定义元素的混合级别
        # num_threads:值大于1,使用多个线程在tensor_list中读取文件,这样保证了同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件
        if FLAGS.mode == 'train':
            paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.shuffle_batch([output[0], output[1], input_images, target_images],
                                            batch_size=FLAGS.batch_size, capacity=FLAGS.image_queue_capacity+40*FLAGS.batch_size,
                                            min_after_dequeue=FLAGS.image_queue_capacity, num_threads=FLAGS.queue_thread)
        else:
            paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.batch([output[0], output[1], input_images, target_images],
                                            batch_size=FLAGS.batch_size, num_threads=FLAGS.queue_thread, allow_smaller_final_batch=True)

        steps_per_epoch = int(math.ceil(len(image_list_LR) / FLAGS.batch_size))
        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
            targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
        elif FLAGS.task == 'denoise':
            inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
            targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
    return Data(
        # batch的文件名队列
        paths_LR=paths_LR_batch,
        paths_HR=paths_HR_batch,
        # 所有batch对应的输入样本
        inputs=inputs_batch,
        targets=targets_batch,
        #总的图像的个数
        image_count=len(image_list_LR),
        #一个epoch内batch的个数
        steps_per_epoch=steps_per_epoch
    )

        调用SRGAN网络Net = SRGAN(data.inputs, data.targets, FLAGS)。得到一组返回值,有了这些返回值就可以开启sv,使用summary显示训练过程,在会话中,summary信息连同其他关键信息被放入fetches,执行results = sess.run(fetches),然后通过result即可打印出训练过程的关键信息。

你可能感兴趣的:(Python)