【创新实训】风格迁移功能探索与实现(五)eval_model.py 验证模型的编写

eval_model.py可以参考train_model.py编写,因此差不多

大体思路就是:

1.读入图片并进行预处理
2.从保存的风格ckpt文件中恢复模型权重(注意不需要vgg)

3.将图片tensor输入到net中,得到转换后的image并保存


代码及注释:

# coding: utf-8
from __future__ import print_function
import tensorflow as tf
from preprocessing import preprocessing_factory
import reader
import model
import time
import os
"""
编码思路:
1.读入图片并进行预处理
2.从保存的风格ckpt文件中恢复模型权重(注意不需要vgg)
3.将图片tensor输入到net中,得到转换后的image并保存
"""

######################
# define the parameter#
######################
tf.app.flags.DEFINE_string('loss_model', 'vgg_16', '损失网络模型名 ')
tf.app.flags.DEFINE_string('loss_model_file', 'loss_model_ckpt/vgg_16.ckpt', '损失网络ckpt文件路径 ')
tf.app.flags.DEFINE_integer('image_size', 256, '图像大小')

#model的ckpt相关
tf.app.flags.DEFINE_string("model_path", "transfer_model_ckpt", "风格ckpt文件路径")
tf.app.flags.DEFINE_string("model_name", "candy", "风格名")
tf.app.flags.DEFINE_string("model_file", "models.ckpt", "风格ckpt文件名")

#内容图片与风格图片
tf.app.flags.DEFINE_string("image_file", "srcImg/test.jpg", "输入模型的图片路径")
tf.app.flags.DEFINE_string("res_file", "resImg", "模型输出的图片保存目录")
tf.app.flags.DEFINE_string("res_image", "res.jpg", "模型输出的图片保存目录")
tf.app.flags.DEFINE_string("style_image", "styleImg/candy.jpg", "风格图片的路径")

#损失函数权重参数
tf.app.flags.DEFINE_float('content_weight', 1.0, '内容损失函数权重')
tf.app.flags.DEFINE_float('style_weight', 100.0, '风格损失函数权重')
tf.app.flags.DEFINE_float('tv_weight', 0.5, 'total variation损失函数权重')

#训练数据相关参数
tf.app.flags.DEFINE_integer( 'batch_size', 128, 'batch大小')
tf.app.flags.DEFINE_integer( 'epoch', 2, 'epoch个数')

#layers
tf.app.flags.DEFINE_list("content_layers", "vgg_16/conv3/conv3_3", "用于计算内容损失的layers")
tf.app.flags.DEFINE_list("style_layers", ["vgg_16/conv1/conv1_2",
                                          "vgg_16/conv2/conv2_2"
                                          "vgg_16/conv3/conv3_3"
                                          "vgg_16/conv4/conv4_3"], "用于计算风格损失的layers")
tf.app.flags.DEFINE_string("checkpoint_exclude_scopes", "vgg_16/fc", "不从ckpt中恢复权重的层")

#learning_rate
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')

FLAGS = tf.app.flags.FLAGS

height = 0
width = 0
def main(_):

    #指定image路径,读取图片获取宽和高
    FLAGS.model_file=FLAGS.model_path+ FLAGS.model_file
    with open(FLAGS.image_file, 'rb') as img:
        with tf.Session().as_default() as sess:
            if FLAGS.image_file.lower().endswith('png'):
                image = sess.run(tf.image.decode_png(img.read()))
            else:
                image = sess.run(tf.image.decode_jpeg(img.read()))
            height = image.shape[0]
            width = image.shape[1]

    with tf.Graph().as_default():
        with tf.Session().as_default() as sess:

            # 读入image数据,并进行预处理
            image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)

            # 增加batch维度
            image = tf.expand_dims(image, 0)

            #转换网络模型的输出,真正运行是在后面恢复权重以后
            generated = model.net(image, training=False)
            generated = tf.cast(generated, tf.uint8)#转换数据格式

            # 去除batch维度
            generated = tf.squeeze(generated, [0])


            saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            #从已训练的风格转换模型的ckpt文件中恢复权重
            FLAGS.model_file = os.path.abspath(FLAGS.model_file)
            saver.restore(sess, FLAGS.model_file)

            generated_file = FLAGS.res_file+FLAGS.res_img
            if os.path.exists(FLAGS.res_file) is False:
                os.makedirs(FLAGS.res_file)

            # 生成图片
            with open(generated_file, 'wb') as img:
                start_time = time.time()
                img.write(sess.run(tf.image.encode_jpeg(generated)))
                end_time = time.time()
                tf.logging.info('Elapsed time: %fs' % (end_time - start_time))

                tf.logging.info('Done. Please check %s.' % generated_file)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()

你可能感兴趣的:(项目创新实训)