SRGAN_tensorflow包含两个网络,分别是SRGAN网络和VGG网络。其中,SRGAN网络由生成器网络generator和判决器网络discriminator组成。根据原始论文(Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network)框架,其训练过程如下;
本文尽可能以通俗易懂的方式介绍下超分辨TensorFlow实现的代码。从main脚本文件开始,首先要需要输入命令行参数作为整个srgan的超参数。条件语句用来检测某些关键参数如输入输出路径是否传入或存在。在这些准备好之后,程序就要开始它的核心工作了,这里核心工作共有三个,test mode(利用训练好的generator模型在LR上测试),inference mode(利用训练好的generator在任意数据集上测试),以及train mode(训练模式包含两个task,SRResnet和SRGAN ,每个task下的perceptual_mode 有三种模式可选,VGG54,VGG22,MSE)。在train mode模式下,上述过程1工作时,task=SRResnet,perceptual_mode=MSE;过程2工作时,task=SRGAN,perceptual_mode=MSE;过程3工作时,task=SRGAN,perceptual_mode=VGG54/22。每个模式的详细功能分别介绍如下。
1.test mode。
任务是把低分辨数据集放到训练好的模型上测试并保存结果。进入该模式第一步就是调用mode.py中的test_data_loader(FLAGS)函数读入数据,os.listdir() 方法用于返回路径FLAGS.input_dir_LR下包含的所有文件的名字的列表。将文件名列表分别和各自对应的具体路径拼接在一起,得到两个包含文件名路径的完整地址列表image_list_LR和image_list_HR。接着,利用for循环调用preprocess_test(name, mode)函数读入image_list_LR和image_list_HR对应的图像列表,得到image_LR(map(0,1))和image_HR(map(-1,1))。最终返回一个大的列表test_data,包含两个完整路径和两个完整图像列表。同时定义四个占位符,后面使用时直接给占位符赋值。
inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
targets_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='targets_raw')
path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')
接下来根据工作模式将数据送入generator,这里的inputs_raw占位符在session中会被赋以真实值test_data.inputs
with tf.variable_scope('generator'):
if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
else:
raise NotImplementedError('Unknown task!!')
generator是实现超分网络的模型,主要由多个残差块和卷积层组成,详细信息如下:
# Definition of the generator
def generator(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
# Check the flag
if FLAGS is None:
raise ValueError('No FLAGS is provided for generator')
# The Bx residual blocks
def residual_block(inputs, output_channel, stride, scope):
with tf.variable_scope(scope):
#3x3kernel
net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
net = batchnorm(net, FLAGS.is_training)
net = prelu_tf(net)
#3x3kernel
net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2')
net = batchnorm(net, FLAGS.is_training)
#skip connection
net = net + inputs
return net
with tf.variable_scope('generator_unit', reuse=reuse):
# The input layer
with tf.variable_scope('input_stage'):
# kernel size 9x9 kernel count(feature map) 64 stride 1
net = conv2(gen_inputs, 9, 64, 1, scope='conv')
net = prelu_tf(net)
stage1_output = net
# The residual block parts
for i in range(1, FLAGS.num_resblock+1 , 1):
name_scope = 'resblock_%d'%(i)
net = residual_block(net, 64, 1, name_scope)
with tf.variable_scope('resblock_output'):
net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
net = batchnorm(net, FLAGS.is_training)
net = net + stage1_output
with tf.variable_scope('subpixelconv_stage1'):
net = conv2(net, 3, 256, 1, scope='conv')
net = pixelShuffler(net, scale=2)
net = prelu_tf(net)
with tf.variable_scope('subpixelconv_stage2'):
net = conv2(net, 3, 256, 1, scope='conv')
net = pixelShuffler(net, scale=2)
net = prelu_tf(net)
with tf.variable_scope('output_stage'):
net = conv2(net, 9, gen_output_channels, 1, scope='conv')
return net
直接读入的数据先要经过conver image进行map处理,处理之后然后转为uint8格式得到converted_inputs,converted_targets,converted_outputs。
with tf.name_scope('convert_image'):
# Deprocess the images outputed from the model
inputs = deprocessLR(inputs_raw)#不做处理return tf.identity(image)
targets = deprocess(targets_raw)# #map处理 [-1, 1] => [0, 1](return (image + 1) / 2)
outputs = deprocess(gen_output)
# Convert back to uint8
converted_inputs = tf.image.convert_image_dtype(inputs, dtype=tf.uint8, saturate=True)
converted_targets = tf.image.convert_image_dtype(targets, dtype=tf.uint8, saturate=True)
converted_outputs = tf.image.convert_image_dtype(outputs, dtype=tf.uint8, saturate=True)
构造fetch,包含信息如下,在session.run中使用
with tf.name_scope('encode_image'):
save_fetch = {
"path_LR": path_LR,
"path_HR": path_HR,
"inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name='input_pngs'),
"outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name='output_pngs'),
"targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name='target_pngs')
}
# Define the weight initiallizer (In test, we only need to restore the weight of the generator)
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
weight_initiallizer = tf.train.Saver(var_list)
# Define the initialization operation
init_op = tf.global_variables_initializer()
指定session的配置信息,开启session。每次循环从输入数据(LR)中拿出一张图送入训练好的生成器网络,将save_fetch的信息保存下来作为测试结果。
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
# Load the pretrained model
print('Loading weights from the pre-trained model')
weight_initiallizer.restore(sess, FLAGS.checkpoint)
max_iter = len(test_data.inputs)
print('Evaluation starts!!')
for i in range(max_iter):
input_im = np.array([test_data.inputs[i]]).astype(np.float32)
target_im = np.array([test_data.targets[i]]).astype(np.float32)
path_lr = test_data.paths_LR[i]
path_hr = test_data.paths_HR[i]
results = sess.run(save_fetch, feed_dict={inputs_raw: input_im, targets_raw: target_im,
path_LR: path_lr, path_HR: path_hr})
filesets = save_images(results, FLAGS)
for i, f in enumerate(filesets):
print('evaluate image', f['name'])
2.inference mode
任务是利用train好的网络对任意输入图像(可以使用自己的图像)进行超分处理并保存结果。
inference_data = inference_data_loader(FLAGS)#调用model.py中的inference_data_loader函数载入数据。
其中inference_data_loader函数的详细定义如下。对输入数据简单处理后返回一个列表,包含完整的图像地址和对应的图像。
# The inference data loader. Allow input image with different size
def inference_data_loader(FLAGS):
# Get the image name list
if (FLAGS.input_dir_LR == 'None'):
raise ValueError('Input directory is not provided')
if not os.path.exists(FLAGS.input_dir_LR):
raise ValueError('Input directory not found')
image_list_LR_temp = os.listdir(FLAGS.input_dir_LR)
image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp if _.split('.')[-1] == 'png']
# Read in and preprocess the images
def preprocess_test(name):
im = sic.imread(name, mode="RGB").astype(np.float32)
# check grayscale image
if im.shape[-1] != 3:
h, w = im.shape
temp = np.empty((h, w, 3), dtype=np.uint8)
temp[:, :, :] = im[:, :, np.newaxis]
im = temp.copy()
im = im / np.max(im)
return im
image_LR = [preprocess_test(_) for _ in image_list_LR]
# Push path and image into a list
Data = collections.namedtuple('Data', 'paths_LR, inputs')
return Data(
paths_LR=image_list_LR,
inputs=image_LR
)
后面的过程和test模式下的操作完全类似。只是test模式保存的信息包括了真实图片,可以对比超分前后和真实图片的差别。推理模式只保存超分前后的图片,和真实图的区别无从比较。
3.train mode
该模式根据输入的真实高分图片和降采样后的图片分别训练discriminator和generator,得到完整的SRGAN网络。
data = data_loader(FLAGS)
data_loader()函数具体信息如下:图像载入阶段完成图像文件名队列操作,张量转换;图像加载后预处理,解码,类型调整,数据增强(裁剪、缩放,翻转、扭曲),使输入网络训练信息多样化,缓解过拟合。训练阶段,使用tf.train.shuffle_batch:将队列output[0], output[1](output是数据初次载入值)中未做任何处理的数据打乱后,读取出来得到paths_LR_batch, paths_HR_batch,将增强后的数据input_images, target_images打乱后读取出来得到inputs_batch, targets_batch。根据输入图片数和batch数计算每个epoch中包含的step,执行一个batch算作一步。最后返回一个总的列表,包含四组数据batch, #总的图像的个数image_count=len(image_list_LR),和一个epoch内batch的个数 steps_per_epoch=steps_per_epoch。
# Define the dataloader
def data_loader(FLAGS):
with tf.device('/cpu:0'):
# Define the returned data batches
Data = collections.namedtuple('Data', 'paths_LR, paths_HR, inputs, targets, image_count, steps_per_epoch')
#Check the input directory
if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'):
raise ValueError('Input directory is not provided')
if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)):
raise ValueError('Input directory not found')
#os.listdir(path) 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中。循环中用_代替仅获取值而已
image_list_LR = os.listdir(FLAGS.input_dir_LR)
#所有.png组成的列表
image_list_LR = [_ for _ in image_list_LR if _.endswith('.png')]
#通过列表长度判空进而抛出异常
if len(image_list_LR)==0:
raise Exception('No png files in the input directory')
#List的成员函数sort对给定的List 进行排序
image_list_LR_temp = sorted(image_list_LR)
#文件路径拼接,路径+文件名组合成完整路径,LR按顺序,HR不按顺序
image_list_LR = [os.path.join(FLAGS.input_dir_LR, _) for _ in image_list_LR_temp]
image_list_HR = [os.path.join(FLAGS.input_dir_HR, _) for _ in image_list_LR_temp]
#用于将不同数据变成张量:比如可以让数组变成张量、也可以让列表变成张量。
# dtype: Optional element type for the returned tensor. If missing, the(可以指定转化成tensor后输出的数据类型) type is inferred from the type of `value`.
image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)
#当reuse为False或者None时(这也是默认值),同一个tf.variable_scope下面的变量名不能相同;
with tf.variable_scope('load_image'):
# define the image list queue
# image_list_LR_queue = tf.train.string_input_producer(image_list_LR, shuffle=False, capacity=FLAGS.name_queue_capacity)
# image_list_HR_queue = tf.train.string_input_producer(image_list_HR, shuffle=False, capacity=FLAGS.name_queue_capacity)
#print('[Queue] image list queue use shuffle: %s'%(FLAGS.mode == 'Train'))
#tf.train.slice_input_producer:tensor生成器,这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。
#参数shuffle: bool类型,设置是否打乱样本的顺序。如果shuffle=True,生成的样本顺序就被打乱了,
#如果shuffle=False,样本顺序未被打乱,需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本。参数capacity:设置tensor列表的容量。
output = tf.train.slice_input_producer([image_list_LR_tensor, image_list_HR_tensor],
shuffle=False, capacity=FLAGS.name_queue_capacity)
# Reading and decode the images
## reader从文件名队列output中读数据。对应的方法是reader.read
reader = tf.WholeFileReader(name='image_reader')
#提取LR图片内容和HR图片内容,一定注意数据之间的转化;
image_LR = tf.read_file(output[0])
image_HR = tf.read_file(output[1])
#channels必须要制定
input_image_LR = tf.image.decode_png(image_LR, channels=3)
input_image_HR = tf.image.decode_png(image_HR, channels=3)
#图片数据进行转化,此处为了显示而转化
input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32)
input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32)
#如果tf.shape(input_image_LR)[2]和3不一致就抛出异常
assertion = tf.assert_equal(tf.shape(input_image_LR)[2], 3, message="image does not have 3 channels")
#tf.identity返回一个和输入的 tensor 大小和数值都一样的 tensor ,类似于 y=x 操作,主要的用途就是更好的控制在不同设备间传递变量的值
with tf.control_dependencies([assertion]):
input_image_LR = tf.identity(input_image_LR)
input_image_HR = tf.identity(input_image_HR)
# Normalize the low resolution image to [0, 1], high resolution to [-1, 1]
a_image = preprocessLR(input_image_LR)
b_image = preprocess(input_image_HR)
inputs, targets = [a_image, b_image]
# The data augmentation part
with tf.name_scope('data_preprocessing'):
with tf.name_scope('random_crop'):
# Check whether perform crop
if (FLAGS.random_crop is True) and FLAGS.mode == 'train':
print('[Config] Use random crop')
# Set the shape of the input image. the target will have 4X size
input_size = tf.shape(inputs)
target_size = tf.shape(targets)
#用于改变某个张量的数据类型 图像增强后的高宽
offset_w = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
dtype=tf.int32)
offset_h = tf.cast(tf.floor(tf.random_uniform([], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
dtype=tf.int32)
if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
#tf.image.crop_to_bounding_box:将图像裁剪到指定的边界框(offset_h, offset_w)代表新图像左上角坐标FLAGS.crop_size x FLAGS.crop_size:结果的高度和高度
inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size,
FLAGS.crop_size)
targets = tf.image.crop_to_bounding_box(targets, offset_h*4, offset_w*4, FLAGS.crop_size*4,
FLAGS.crop_size*4)
elif FLAGS.task == 'denoise':
inputs = tf.image.crop_to_bounding_box(inputs, offset_h, offset_w, FLAGS.crop_size,
FLAGS.crop_size)
targets = tf.image.crop_to_bounding_box(targets, offset_h, offset_w,
FLAGS.crop_size, FLAGS.crop_size)
# Do not perform crop裁剪
else:
inputs = tf.identity(inputs)
targets = tf.identity(targets)
with tf.variable_scope('random_flip'):
# Check for random flip翻转:
if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
print('[Config] Use random flip')
# Produce the decision of random flip产生均匀分布随机数
decision = tf.random_uniform([], 0, 1, dtype=tf.float32)
input_images = random_flip(inputs, decision)
target_images = random_flip(targets, decision)
else:
input_images = tf.identity(inputs)
target_images = tf.identity(targets)
if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
target_images.set_shape([FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
elif FLAGS.task == 'denoise':
input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
target_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
# tf.train.shuffle_batch:将队列output[0], output[1]中数据打乱后,再读取出来,因此队列中剩下的数据也是乱序的,队头也是一直在补充
# 读取一个文件并且加载一个张量中的batch_size行 capacity是队列的长度
# tensor_list:入队的张量列表 batch_size:表示进行一次批处理的tensors数量 capacity:一个整数,队列中的最大的元素数
# min_after_dequeue:当一次出列操作完成后,队列中元素的最小数量,往往用于定义元素的混合级别
# num_threads:值大于1,使用多个线程在tensor_list中读取文件,这样保证了同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件
if FLAGS.mode == 'train':
paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.shuffle_batch([output[0], output[1], input_images, target_images],
batch_size=FLAGS.batch_size, capacity=FLAGS.image_queue_capacity+40*FLAGS.batch_size,
min_after_dequeue=FLAGS.image_queue_capacity, num_threads=FLAGS.queue_thread)
else:
paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.batch([output[0], output[1], input_images, target_images],
batch_size=FLAGS.batch_size, num_threads=FLAGS.queue_thread, allow_smaller_final_batch=True)
steps_per_epoch = int(math.ceil(len(image_list_LR) / FLAGS.batch_size))
if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size*4, FLAGS.crop_size*4, 3])
elif FLAGS.task == 'denoise':
inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
return Data(
# batch的文件名队列
paths_LR=paths_LR_batch,
paths_HR=paths_HR_batch,
# 所有batch对应的输入样本
inputs=inputs_batch,
targets=targets_batch,
#总的图像的个数
image_count=len(image_list_LR),
#一个epoch内batch的个数
steps_per_epoch=steps_per_epoch
)
调用SRGAN网络Net = SRGAN(data.inputs, data.targets, FLAGS)。得到一组返回值,有了这些返回值就可以开启sv,使用summary显示训练过程,在会话中,summary信息连同其他关键信息被放入fetches,执行results = sess.run(fetches),然后通过result即可打印出训练过程的关键信息。