最难的还是train_model.py的编写
其实读论文理解起来没有那么难
但是用代码实现的话就要好多天,还不能保证没有bug...
train.py中最难的还是Loss function模块的编写
因为这也是论文的创新点,没有代码可以参考
train.py分成四个模块:定义参数模块、读入数据模块、loss模块、训练模块
各模块代码及注释如下:
######################
# define the parameter#
######################
# Loss network: name, checkpoint file and the image size used for training.
tf.app.flags.DEFINE_string('loss_model', 'vgg_16', '损失网络模型名 ')
tf.app.flags.DEFINE_string('loss_model_file', 'loss_model_ckpt/vgg_16.ckpt', '损失网络ckpt文件路径 ')
tf.app.flags.DEFINE_integer('image_size', 256, '图像大小')
# Checkpoint settings for the style-transfer model
tf.app.flags.DEFINE_string("model_path", "transfer_model_ckpt", "风格ckpt文件路径")
tf.app.flags.DEFINE_string("model_name", "candy", "风格名")
tf.app.flags.DEFINE_string("model_file", "models.ckpt", "风格ckpt文件名")
# Content image, output directory and style image
tf.app.flags.DEFINE_string("image_file", "srcImg/test.jpg", "输入模型的图片路径")
tf.app.flags.DEFINE_string("res_file", "resImg", "模型输出的图片保存目录")
tf.app.flags.DEFINE_string("style_image", "styleImg/candy.jpg", "风格图片的路径")
# Weights of the individual loss terms
tf.app.flags.DEFINE_float('content_weight', 1.0, '内容损失函数权重')
tf.app.flags.DEFINE_float('style_weight', 100.0, '风格损失函数权重')
tf.app.flags.DEFINE_float('tv_weight', 0.00001, 'total variation损失函数权重')
# Training data parameters
tf.app.flags.DEFINE_integer('batch_size', 128, 'batch大小')
tf.app.flags.DEFINE_integer('epoch', 2, 'epoch个数')
# Loss-network layers used for the content and style losses
tf.app.flags.DEFINE_list("content_layers", "vgg_16/conv3/conv3_3", "用于计算内容损失的layers")
# Every layer name needs its own trailing comma: adjacent string literals
# without a comma are concatenated by Python into one (non-existent) layer name.
tf.app.flags.DEFINE_list("style_layers", ["vgg_16/conv1/conv1_2",
                                          "vgg_16/conv2/conv2_2",
                                          "vgg_16/conv3/conv3_3",
                                          "vgg_16/conv4/conv4_3"], "用于计算风格损失的layers")
tf.app.flags.DEFINE_string("checkpoint_exclude_scopes", "vgg_16/fc", "不从ckpt中恢复权重的层")
# Learning rate
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
FLAGS = tf.app.flags.FLAGS
读取数据模块:
######################
# read the data #
######################
def readImage(path, height, width, preprocess_fn):
    """Read a single PNG or JPEG file and return it after preprocessing.

    The decoder is selected from the file name: a path ending in ``png``
    (case-insensitive) is decoded as PNG, anything else as JPEG. Both
    decoders are asked for 3 channels.

    Key functions:
        tf.read_file
        tf.image.decode_png / tf.image.decode_jpeg

    :param path: image path; its suffix decides which decoder is used
    :param height: forwarded unchanged to ``preprocess_fn``
    :param width: forwarded unchanged to ``preprocess_fn``
    :param preprocess_fn: callable invoked as ``preprocess_fn(image, height, width)``
    :return: the value returned by ``preprocess_fn`` for the decoded image
    """
    is_png = path.lower().endswith('png')
    img_data = tf.read_file(path)
    if is_png:
        image = tf.image.decode_png(img_data, channels=3)
    else:
        # The decoded tensor has to be bound to `image`; otherwise the return
        # statement below fails for every non-PNG input.
        image = tf.image.decode_jpeg(img_data, channels=3)
    return preprocess_fn(image, height, width)
def readImageBatch(batch_size, height, width, path, preprocess_fn, epochs=2, shuffle=True):
    """Build a queue-based input pipeline that yields preprocessed image batches.

    Every regular file directly inside ``path`` is enqueued; the files are
    read whole, decoded with 3 channels, preprocessed and batched.

    Key functions:
        tf.train.string_input_producer
        tf.WholeFileReader
        reader.read
        tf.train.batch

    :param batch_size: mini-batch size
    :param height: image height, forwarded to ``preprocess_fn``
    :param width: image width, forwarded to ``preprocess_fn``
    :param path: directory holding the training images (MS-COCO in the original post)
    :param preprocess_fn: callable invoked as ``preprocess_fn(image, height, width)``
    :param epochs: passed as ``num_epochs`` to the filename queue
    :param shuffle: whether the filename queue shuffles; when False the
        file names are sorted first
    :return: the tensor produced by ``tf.train.batch`` for the processed images
    """
    candidates = (join(path, entry) for entry in listdir(path))
    filenames = [candidate for candidate in candidates if isfile(candidate)]
    if not shuffle:
        filenames.sort()

    # The decoder for the whole data set is picked from the first file's suffix.
    if filenames[0].lower().endswith('png'):
        decode = tf.image.decode_png
    else:
        decode = tf.image.decode_jpeg

    # The filename queue emits each name `epochs` times and then signals
    # OutOfRange, which is what ends the training loop.
    filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle, num_epochs=epochs)

    # WholeFileReader.read returns (filename, file contents); only the
    # contents are needed here.
    _, raw_bytes = tf.WholeFileReader().read(filename_queue)
    decoded = decode(raw_bytes, channels=3)
    prepared = preprocess_fn(decoded, height, width)
    return tf.train.batch([prepared], batch_size, dynamic_pad=True)
定义损失函数模块:
######################
# define the loss #
######################
def gram(layer):
    """Compute the normalised Gram matrix of each item in a batch of feature maps.

    ``layer`` is indexed as a 4-D tensor. Axes 1 and 2 are flattened into
    one axis, the flattened features are multiplied with their own
    transpose, and the product is divided by the number of elements per
    batch item (dim1 * dim2 * channels).

    Key function:
        tf.reshape

    :param layer: output tensor of an activation layer (4-D, channels last)
    :return: tensor of shape [batch, channels, channels]
    """
    dims = tf.shape(layer)
    batch, dim1, dim2, channels = dims[0], dims[1], dims[2], dims[3]

    # [batch, dim1, dim2, channels] -> [batch, dim1 * dim2, channels]
    flattened = tf.reshape(layer, tf.stack([batch, -1, channels]))
    products = tf.matmul(flattened, flattened, transpose_a=True)
    normalizer = tf.to_float(dim1 * dim2 * channels)
    return products / normalizer
def get_styleImg_featuremaps():
    """Compute the style targets: Gram matrices of the style image's features.

    Workflow:
        1. Read and decode ``FLAGS.style_image`` (always as 3 channels).
        2. Preprocess it and run it through the loss network in a
           separate graph.
        3. For every layer in ``FLAGS.style_layers`` take that endpoint,
           compute its Gram matrix and drop the batch dimension.

    As a side effect the un-preprocessed style image is written to
    ``<FLAGS.res_file>/target_transfer_style_<FLAGS.model_name>.jpg``; the
    output directory is created when missing.

    :return: a list with one evaluated Gram matrix per style layer, in the
        order of ``FLAGS.style_layers``; used to compute the style loss
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        # Read the style image. channels=3 matches the other image readers in
        # this file and keeps grayscale / RGBA inputs from reaching the loss
        # network with 1 or 4 channels.
        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes, channels=3)
        else:
            image = tf.image.decode_jpeg(img_bytes, channels=3)

        # NOTE(review): the exact resize/crop behaviour depends on
        # preprocessing_factory, which is not visible here.
        image = image_preprocessing_fn(image, size, size)

        # Add a batch dimension: the network expects a 4-D input tensor.
        images = tf.expand_dims(image, 0)

        # endpoints_dict maps layer names (e.g. "vgg_16/conv3/conv3_3") to
        # that layer's output for the style image.
        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        features = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            feature = tf.squeeze(gram(feature), [0])  # remove the batch dimension
            features.append(feature)

        with tf.Session() as sess:
            # Restore variables for the loss network.
            init_func = get_init_fn()
            init_func(sess)

            if not os.path.exists(FLAGS.res_file):
                os.makedirs(FLAGS.res_file)

            # Write the preprocessed-then-unprocessed style image for reference.
            save_file = FLAGS.res_file + '/target_transfer_style_' + FLAGS.model_name + '.jpg'
            with open(save_file, 'wb') as f:
                target_image = image_unprocessing_fn(images[0, :])
                value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
                f.write(sess.run(value))
                tf.logging.info('目标风格图片保存到: %s.' % save_file)

            # Evaluate and return the Gram matrices used as style targets.
            return sess.run(features)
def get_style_loss(endpoints_dict, style_featuremaps, style_layers):
    """Compute the style loss over the given layers.

    For each layer the first half (along axis 0) of the endpoint is taken
    as the generated images' activations. The loss for that layer is the
    squared L2 distance between their Gram matrices and the style target,
    divided by the number of activation elements.

    :param endpoints_dict: layer name -> loss-network output for the
        concatenation [generated images, content images]
    :param style_featuremaps: one style target per layer, aligned with
        ``style_layers``. These are the values returned by
        ``get_styleImg_featuremaps`` and are ALREADY Gram matrices with the
        batch dimension removed.
    :param style_layers: names of the layers that contribute to the style loss
    :return: the summed style loss (0 when no layers are given)
    """
    style_loss = 0
    for style_gram, layer in zip(style_featuremaps, style_layers):
        # Keep only the generated half of the batch; the content half is
        # irrelevant for the style loss.
        generated_images, _ = tf.split(endpoints_dict[layer], 2, 0)
        size = tf.size(generated_images)
        # The target is subtracted as-is. Passing it through gram() again would
        # be wrong: gram() indexes four dimensions, while the target is a 2-D
        # Gram matrix that broadcasts over the batch axis.
        layer_style_loss = tf.nn.l2_loss(gram(generated_images) - style_gram) * 2 / tf.to_float(size)
        style_loss += layer_style_loss
    return style_loss
def get_content_loss(endpoints_dict, content_layers):
    """Compute the content loss as a size-normalised squared L2 distance.

    :param endpoints_dict: layer name -> loss-network output for the
        concatenation [generated images, content images]
    :param content_layers: names of the layers that contribute to the content loss
    :return: the summed content loss (0 when no layers are given)
    """
    def _layer_loss(activations):
        # Split along axis 0 into two equal parts: the first half belongs to
        # the generated images, the second half to the original content images.
        generated_half, content_half = tf.split(activations, 2, 0)
        element_count = tf.size(generated_half)
        # tf.nn.l2_loss returns sum(x ** 2) / 2, hence the factor 2.
        squared_distance = tf.nn.l2_loss(generated_half - content_half) * 2
        return squared_distance / tf.to_float(element_count)

    return sum(_layer_loss(endpoints_dict[name]) for name in content_layers)
def total_variation_loss(layer):
    """Total-variation loss that encourages spatial smoothness in the output image.

    Differences between neighbouring positions along axis 1 and along
    axis 2 are penalised with a size-normalised L2 loss.

    :param layer: 4-D tensor; axes 1 and 2 are treated as the spatial axes
    :return: scalar loss tensor
    """
    dims = tf.shape(layer)
    rows = dims[1]
    cols = dims[2]

    # tf.slice(input_, begin, size, name=None): a size of -1 keeps everything
    # from `begin` to the end of that axis.
    to_end = [-1, -1, -1, -1]

    # Shifted subtraction along axis 1: rows [0, rows-1) minus rows [1, rows).
    upper = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, rows - 1, -1, -1]))
    lower = tf.slice(layer, [0, 1, 0, 0], to_end)
    row_diff = upper - lower

    # Shifted subtraction along axis 2: columns [0, cols-1) minus columns [1, cols).
    left = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, cols - 1, -1]))
    right = tf.slice(layer, [0, 0, 1, 0], to_end)
    col_diff = left - right

    # Normalise each term by the number of elements it was computed from.
    col_term = tf.nn.l2_loss(col_diff) / tf.to_float(tf.size(col_diff))
    row_term = tf.nn.l2_loss(row_diff) / tf.to_float(tf.size(row_diff))
    return col_term + row_term
模型训练部分:
######################
# train the model #
######################
def get_init_fn():
    """Build the function that restores the loss-model weights from its checkpoint.

    Model variables whose op name starts with one of the comma-separated
    scopes in ``FLAGS.checkpoint_exclude_scopes`` are left out; everything
    else is restored from ``FLAGS.loss_model_file``.

    :return: the callable created by ``slim.assign_from_checkpoint_fn``
    """
    scope_spec = FLAGS.checkpoint_exclude_scopes
    if scope_spec:
        skipped_prefixes = tuple(scope.strip() for scope in scope_spec.split(','))
    else:
        skipped_prefixes = ()

    # str.startswith accepts a tuple of prefixes; an empty tuple never matches,
    # so with no exclusions every model variable is restored.
    restorable = [
        variable
        for variable in slim.get_model_variables()
        if not variable.op.name.startswith(skipped_prefixes)
    ]

    return slim.assign_from_checkpoint_fn(
        FLAGS.loss_model_file,
        restorable,
        ignore_missing_vars=True)
def main():
    """Build the training graph for the style-transfer network and train it.

    Steps:
        1. Create the checkpoint directory
           ``<FLAGS.model_path>/<FLAGS.model_name>`` when it is missing.
        2. Read content-image batches, run them through ``model.base_net``
           and feed generated and content images together through the
           loss network.
        3. Combine the content, style and total-variation losses using the
           weights from FLAGS.
        4. Optimise with Adam only the trainable variables whose name does
           not start with ``FLAGS.loss_model``.
        5. Train until the input queue is exhausted, logging every 10 steps
           and saving a checkpoint every 1000 steps and at the end.
    """
    # training_path is the directory where style checkpoints are stored.
    # The flag is called `model_name`; there is no flag named `name`.
    training_path = os.path.join(FLAGS.model_path, FLAGS.model_name)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # ---- Build processed_images and processed_generated ----
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = readImageBatch(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                              'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            # `generated` is the image batch produced by the style-transfer network.
            generated = model.base_net(processed_images, training=True)
            # Unstack along the batch axis so each generated image can be
            # preprocessed for the loss network individually, then restack.
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            processed_generated = tf.stack(processed_generated, axis=0)
            # Generated images go first and content images second; the loss
            # functions split the endpoints in that order.
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network.
            for key in endpoints_dict:
                tf.logging.info(key)

            # ---- Losses ----
            # get_styleImg_featuremaps reads FLAGS itself and takes no arguments.
            style_featuremaps = get_styleImg_featuremaps()
            content_loss = get_content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss = get_style_loss(endpoints_dict, style_featuremaps, FLAGS.style_layers)
            tv_loss = total_variation_loss(generated)  # use the unprocessed image
            # Total loss: the optimisation target.
            loss = (FLAGS.style_weight * style_loss
                    + FLAGS.content_weight * content_loss
                    + FLAGS.tv_weight * tv_loss)

            # ---- Training setup ----
            global_step = tf.Variable(0, name="global_step", trainable=False)

            # Train only variables outside the loss network.
            variable_to_train = [
                variable for variable in tf.trainable_variables()
                if not variable.name.startswith(FLAGS.loss_model)
            ]
            train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # The saver covers every global variable outside the loss network.
            variables_to_restore = [
                v for v in tf.global_variables()
                if not v.name.startswith(FLAGS.loss_model)
            ]
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = get_init_fn()
            init_func(sess)

            # Resume the style-transfer model when a checkpoint already exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                saver.restore(sess, last_file)

            # ---- Training loop (core: sess.run(train_op)) ----
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    # global_step is incremented by one on every training step.
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    if step % 10 == 0:
                        tf.logging.info('global_step: %d, total_Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    # Periodic checkpoint.
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, FLAGS.model_name+'.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                # The filename queue is exhausted after FLAGS.epoch epochs:
                # save the final model.
                saver.save(sess, os.path.join(training_path, FLAGS.model_name+'.ckpt-done'))
                tf.logging.info('完成训练')
            finally:
                coord.request_stop()
            coord.join(threads)