eval_model.py can be written with reference to train_model.py, so the two are largely similar.
The overall idea is:
1. Read the image and preprocess it.
2. Restore the model weights from the saved style-transfer checkpoint (no VGG needed).
3. Feed the image tensor into the net, obtain the transformed image and save it.
Code with comments:
# coding: utf-8
from __future__ import print_function
import tensorflow as tf
from preprocessing import preprocessing_factory
import reader
import model
import time
import os
"""
编码思路:
1.读入图片并进行预处理
2.从保存的风格ckpt文件中恢复模型权重(注意不需要vgg)
3.将图片tensor输入到net中,得到转换后的image并保存
"""
######################
# define the parameters #
######################
tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'Name of the loss network')
tf.app.flags.DEFINE_string('loss_model_file', 'loss_model_ckpt/vgg_16.ckpt', 'Path to the loss network checkpoint file')
tf.app.flags.DEFINE_integer('image_size', 256, 'Image size')
# checkpoint of the trained transfer model
tf.app.flags.DEFINE_string("model_path", "transfer_model_ckpt", "Directory of the style-transfer checkpoint")
tf.app.flags.DEFINE_string("model_name", "candy", "Name of the style")
tf.app.flags.DEFINE_string("model_file", "models.ckpt", "File name of the style-transfer checkpoint")
# content image and style image
tf.app.flags.DEFINE_string("image_file", "srcImg/test.jpg", "Path of the image fed into the model")
tf.app.flags.DEFINE_string("res_file", "resImg", "Directory where the output image is saved")
tf.app.flags.DEFINE_string("res_image", "res.jpg", "File name of the output image")
tf.app.flags.DEFINE_string("style_image", "styleImg/candy.jpg", "Path of the style image")
# loss weights
tf.app.flags.DEFINE_float('content_weight', 1.0, 'Weight of the content loss')
tf.app.flags.DEFINE_float('style_weight', 100.0, 'Weight of the style loss')
tf.app.flags.DEFINE_float('tv_weight', 0.5, 'Weight of the total variation loss')
# training-data parameters (not used by this evaluation script, kept for consistency with train_model.py)
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
tf.app.flags.DEFINE_integer('epoch', 2, 'Number of epochs')
# layers
tf.app.flags.DEFINE_list("content_layers", "vgg_16/conv3/conv3_3", "Layers used to compute the content loss")
tf.app.flags.DEFINE_list("style_layers", ["vgg_16/conv1/conv1_2",
                                          "vgg_16/conv2/conv2_2",
                                          "vgg_16/conv3/conv3_3",
                                          "vgg_16/conv4/conv4_3"], "Layers used to compute the style loss")
tf.app.flags.DEFINE_string("checkpoint_exclude_scopes", "vgg_16/fc", "Layers whose weights are not restored from the checkpoint")
#learning_rate
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
FLAGS = tf.app.flags.FLAGS
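# Any of the flags above can be overridden on the command line when the script is launched,
# e.g. (file name taken from this post, values are just examples):
#   python eval_model.py --image_file=srcImg/test.jpg --model_name=candy --res_image=res.jpg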
height = 0
width = 0
def main(_):
    # Full path of the style-transfer checkpoint (directory + file name)
    FLAGS.model_file = os.path.join(FLAGS.model_path, FLAGS.model_file)
    # Read the input image once to obtain its height and width
    with open(FLAGS.image_file, 'rb') as img:
        with tf.Session().as_default() as sess:
            if FLAGS.image_file.lower().endswith('png'):
                image = sess.run(tf.image.decode_png(img.read()))
            else:
                image = sess.run(tf.image.decode_jpeg(img.read()))
            height = image.shape[0]
            width = image.shape[1]
    with tf.Graph().as_default():
        with tf.Session().as_default() as sess:
            # Read in the image data and preprocess it
            image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)
            # Add a batch dimension
            image = tf.expand_dims(image, 0)
            # Output of the transform network; it only actually runs after the weights are restored below
            generated = model.net(image, training=False)
            generated = tf.cast(generated, tf.uint8)  # convert the data type
            # Remove the batch dimension
            generated = tf.squeeze(generated, [0])
            saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
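            # Note: only the transform network's variables exist in this graph; the VGG loss
            # network is never built during evaluation, so its weights are not needed here.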
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # Restore the weights from the trained style-transfer model's checkpoint file
            FLAGS.model_file = os.path.abspath(FLAGS.model_file)
            saver.restore(sess, FLAGS.model_file)
            generated_file = os.path.join(FLAGS.res_file, FLAGS.res_image)
            if not os.path.exists(FLAGS.res_file):
                os.makedirs(FLAGS.res_file)
            # Generate the image and write it to file
            with open(generated_file, 'wb') as img:
                start_time = time.time()
                img.write(sess.run(tf.image.encode_jpeg(generated)))
                end_time = time.time()
                tf.logging.info('Elapsed time: %fs' % (end_time - start_time))
                tf.logging.info('Done. Please check %s.' % generated_file)
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
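After the script finishes, the stylized image is written to resImg/res.jpg (FLAGS.res_file plus FLAGS.res_image). Since reader.get_image is given the height and width measured at the start of main, the output should match the source image's size. A minimal sanity-check sketch (the helper name check_generated is my own and not part of the original project; it just reuses the decode pattern from the top of main):

import tensorflow as tf

def check_generated(generated_file, expected_height, expected_width):
    """Decode the generated JPEG and compare its size against the source image's size."""
    with open(generated_file, 'rb') as f:
        with tf.Session() as sess:
            img = sess.run(tf.image.decode_jpeg(f.read()))
    print('%s decoded as %d x %d (expected %d x %d)' %
          (generated_file, img.shape[1], img.shape[0], expected_width, expected_height))

# Example: call check_generated(generated_file, height, width) at the end of main,
# after the encode_jpeg step has written the result to disk.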