Image Style Transfer Based on Deep Learning: Innovation Practicum 2

This week my work was wrapping the style transfer network behind an interface for outside callers. After a week of training we have obtained network models for seven styles, each saved as a ckpt file.

[Figure 1: the trained checkpoint files]

The interface first determines which style the user selected, then has TensorFlow load the corresponding network model, runs the user's image through the transfer, saves the result, and returns it to the user.
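The style-id dispatch described above amounts to a table lookup. A minimal sketch (the checkpoint file names come from the main code below; the helper name and relative base path are illustrative assumptions, not project code):

```python
# Maps the style id sent by the front end to its checkpoint file name.
# File names match the trained models used in the main code; the helper
# name and default base path are assumptions for illustration.
CHECKPOINTS = {
    '1': 'shuimo.ckpt-done',
    '2': 'cubist.ckpt-6000',
    '3': 'denoised_starry.ckpt-done',
    '4': 'feathers.ckpt-done',
    '5': 'mosaic.ckpt-done',
    '6': 'scream.ckpt-done',
    '7': 'udnie.ckpt-done',
    '8': 'wave.ckpt-done',
    '9': 'jianzhi.ckpt-4000',
}


def checkpoint_for(style, base='trained_models/'):
    """Return the checkpoint path for a style id; raises KeyError if unknown."""
    return base + CHECKPOINTS[style]
```

Compared with a long if/elif chain, the lookup keeps the mapping in one place and fails fast on an unknown style id instead of silently leaving the model file undefined.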

The main code is as follows:

import os
import sys
import time

import tensorflow as tf

# Project-local modules: transfer network, image reader, and the
# slim-style preprocessing factory referenced below.
import model
import reader
from preprocessing import preprocessing_factory


def main(argv):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # argv layout: [script, style id, input image path, output image path]
    style = argv[1]
    rawImg = argv[2]
    genImg = argv[3]
    print(style, rawImg, genImg)
    tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'You can view all the support models in nets/nets_factory.py')
    tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.')

    model_path = "E:\\programming\\PYworkbench\\Style-Transformer-Website\\trained_models\\"

    if style == '1':
        tf.app.flags.DEFINE_string("model_file", model_path + "shuimo.ckpt-done", "")
    elif style == '2':
        tf.app.flags.DEFINE_string("model_file", model_path + "cubist.ckpt-6000", "")
    elif style == '3':
        tf.app.flags.DEFINE_string("model_file", model_path + "denoised_starry.ckpt-done", "")
    elif style == '4':
        tf.app.flags.DEFINE_string("model_file", model_path + "feathers.ckpt-done", "")
    elif style == '5':
        tf.app.flags.DEFINE_string("model_file", model_path + "mosaic.ckpt-done", "")
    elif style == '6':
        tf.app.flags.DEFINE_string("model_file", model_path + "scream.ckpt-done", "")
    elif style == '7':
        tf.app.flags.DEFINE_string("model_file", model_path + "udnie.ckpt-done", "")
    elif style == '8':
        tf.app.flags.DEFINE_string("model_file", model_path + "wave.ckpt-done", "")
    elif style == '9':
        tf.app.flags.DEFINE_string("model_file", model_path + "jianzhi.ckpt-4000", "")

    tf.app.flags.DEFINE_string("image_file",rawImg, "")

    FLAGS = tf.app.flags.FLAGS

    with open(FLAGS.image_file, 'rb') as img:
        with tf.Session().as_default() as sess:
            if FLAGS.image_file.lower().endswith('.png'):
                image = sess.run(tf.image.decode_png(img.read()))
            else:
                image = sess.run(tf.image.decode_jpeg(img.read()))
            height = image.shape[0]
            width = image.shape[1]
    tf.logging.info('Image size: %dx%d' % (width, height))

    with tf.Graph().as_default():
        with tf.Session().as_default() as sess:

            # Read image data.
            image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)

            # Add batch dimension
            image = tf.expand_dims(image, 0)

            generated = model.net(image, training=False)
            generated = tf.cast(generated, tf.uint8)

            # Remove batch dimension
            generated = tf.squeeze(generated, [0])

            # Restore model variables.
            saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # Use absolute path
            FLAGS.model_file = os.path.abspath(FLAGS.model_file)
            saver.restore(sess, FLAGS.model_file)

            # Make sure the directory for the output file exists.
            generated_file = genImg
            out_dir = os.path.dirname(generated_file)
            if out_dir and not os.path.exists(out_dir):
                os.makedirs(out_dir)

            # Generate and write image data to file.
            with open(generated_file, 'wb') as img:
                start_time = time.time()
                img.write(sess.run(tf.image.encode_jpeg(generated)))
                # Style 1 (shuimo) gets an extra sky-segmentation pass.
                if style == '1':
                    cmd = 'python Sky_segment_postProcessing/sky_postprocessing.py '\
                          + rawImg + ' ' + generated_file
                    os.system(cmd)
                end_time = time.time()
                print('Elapsed time: %fs' % (end_time - start_time))
                print('Done. Please check %s.' % generated_file)


if __name__ == '__main__':
    main(sys.argv)
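On the web-server side, a natural way to invoke this entry point is one subprocess per request: `tf.app.flags.DEFINE_*` cannot define the same flag twice in one process, so process isolation gives every request a fresh graph and flag namespace. A minimal sketch, assuming the listing above is saved as `stylize.py` (the script and wrapper names are assumptions):

```python
import subprocess
import sys


def build_command(style, raw_img, gen_img, script='stylize.py'):
    """argv for one request: python <script> <style id> <input> <output>."""
    return [sys.executable, script, str(style), raw_img, gen_img]


def run_style_transfer(style, raw_img, gen_img):
    # check=True raises CalledProcessError if the transfer script fails,
    # which the web layer can turn into an error response.
    return subprocess.run(build_command(style, raw_img, gen_img), check=True)
```

Using `subprocess.run` with an argument list (rather than a shell string, as the `os.system` call above does) also avoids quoting problems with file names that contain spaces.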
