系列文章
(一)图像风格迁移
(二)快速图像风格转换
(三)快速图像风格转换代码解析
def net(image, training):
    """Build the fast style-transfer (image transformation) network.

    Pipeline: reflection padding -> three convolution layers -> five residual
    blocks -> two resize-convolution up-sampling layers -> a 3-channel tanh
    output layer -> rescale to pixel range -> crop the padding border.

    Args:
        image: 4-D image batch (batch, height, width, 3); the shape notes kept
            with this code show e.g. (4, 256, 256, 3) before padding.
        training: flag forwarded unchanged to ``resize_conv2d``.

    Returns:
        The stylized batch with values in (0, 255) (``(tanh + 1) * 127.5``),
        with 10 px cropped from every spatial side, e.g. (4, 256, 256, 3).

    Shapes/tensors of the intermediate stages are printed for inspection.
    """
    # Reflection-pad H and W by 10 px so the border (where edge artifacts
    # appear) can be cut away at the end; e.g. (4, 256, 256, 3) -> (4, 276, 276, 3).
    padded = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')
    print("image shape after padding: {}".format(padded.shape))

    # Argument order used below: conv2d(x, in_depth, out_depth, kernel_size, stride).
    with tf.variable_scope('conv1'):
        # 3 -> 32 channels, 9x9 kernel, stride 1; e.g. [276, 276, 32].
        conv1 = relu(instance_norm(conv2d(padded, 3, 32, 9, 1)))
    print("conv1 shape: {}".format(conv1.shape))

    with tf.variable_scope('conv2'):
        # 32 -> 64 channels, 3x3 kernel, stride-2 down-sampling.
        conv2 = relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
    print("conv2 shape: {}".format(conv2.shape))

    with tf.variable_scope('conv3'):
        # 64 -> 128 channels, 3x3 kernel, stride-2 down-sampling.
        conv3 = relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))

    # Five chained residual blocks at 128 channels, in scopes res1 ... res5.
    features = conv3
    for index in range(1, 6):
        with tf.variable_scope('res{}'.format(index)):
            features = residual(features, 128, 3, 1)
    print("NN processed shape: {}".format(features.get_shape()))

    # Up-sampling via resize + convolution (used instead of conv2d_transpose).
    with tf.variable_scope('deconv1'):
        deconv1 = relu(instance_norm(resize_conv2d(features, 128, 64, 3, 2, training)))
    print("deconv1 shape: {}".format(deconv1.shape))

    with tf.variable_scope('deconv2'):
        deconv2 = relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
    print("deconv2 shape: {}".format(deconv2.shape))

    with tf.variable_scope('deconv3'):
        # Project 32 -> 3 (RGB) channels; tanh bounds the values to (-1, 1).
        deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))
    print("deconv3 shape: {}".format(deconv3.shape))
    print("deconv3 value: {}".format(deconv3))

    # Map tanh range (-1, 1) onto pixel range (0, 255).
    pixels = (deconv3 + 1) * 127.5
    print("processed value: {}".format(pixels))

    # Cut off the 10-px reflection border added at the start.
    rows = tf.shape(pixels)[1]
    cols = tf.shape(pixels)[2]
    cropped = tf.slice(pixels, [0, 10, 10, 0], tf.stack([-1, rows - 20, cols - 20, -1]))
    # Recorded output: Tensor("Slice_1:0", shape=(4, 256, 256, 3), dtype=float32)
    print("final y: {}".format(cropped))
    return cropped
import tensorflow as tf
import os
from preprocessing import preprocessing_factory
import reader
import model
import time
import base64
# Base directory used to locate the model checkpoint. The original expression
# os.path.abspath(os.path.dirname(__name__)) took the dirname of a *module
# name* (e.g. "__main__"), which is "" and therefore always resolved to the
# current working directory; state that behavior explicitly instead.
basedir = os.getcwd()
# Image to stylize (relative to the current working directory).
image_path = "./process/xdqtest_resize.png"

# Read the image once to learn its height and width.
with open(image_path, 'rb') as img:
    image_bytes = img.read()
# Decode in a throwaway graph with a session that is actually closed on exit:
# Session.as_default() alone never closes the session, and building the decode
# ops in the default graph would leave them (plus the raw image bytes as a
# constant) in the graph that the network is later built and saved from.
with tf.Graph().as_default(), tf.Session() as sess:
    if image_path.lower().endswith('png'):
        image = sess.run(tf.image.decode_png(image_bytes))
    else:
        image = sess.run(tf.image.decode_jpeg(image_bytes))
# image is a (height, width, channels) array.
height = image.shape[0]
width = image.shape[1]
if __name__ == "__main__":
    with tf.Session() as sess:
        # Preprocessing closure from the slim preprocessing factory (VGG-16, eval mode).
        image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
            "vgg_16",
            is_training=False)
        # Load the image as a tensor; presumably sized to height x width by
        # reader.get_image -- verify against reader.
        image = reader.get_image(image_path, height, width, image_preprocessing_fn)
        # Add a batch axis: the network pads/slices a 4-D batch.
        image = tf.expand_dims(image, 0)
        # The checkpoint holds only parameters, so the network structure must
        # be rebuilt before restoring.
        generated = model.net(image, training=False)
        # tf.image.encode_jpeg requires a 3-D uint8 image, but the network
        # returns a 4-D float32 batch with values in (0, 255): cast to uint8
        # and drop the batch axis.
        generated = tf.cast(generated, tf.uint8)
        generated = tf.squeeze(generated, [0])
        # Saver over all global (network) variables, V1 checkpoint format.
        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
        # Initialize global and local variables; restore() below then
        # overwrites the global ones with the trained values.
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # Checkpoint path: contains only the trained parameters.
        model_path = os.path.join(basedir, "models", "mosaic.ckpt-done")
        print("model path: {}".format(model_path))
        saver.restore(sess, model_path)
        # Dump every node of the graph (name and op type) for inspection.
        read_graph = sess.graph.as_graph_def()
        for node in read_graph.node:
            print("node name: {}----->node operation: {}".format(node.name, node.op))
        # Output location; create the directory on first run.
        generated_file = os.path.join('generated', 'processed.jpg')
        if not os.path.exists('generated'):
            os.makedirs('generated')
        # Run the network, encode the result as JPEG and write it out.
        with open(generated_file, 'wb') as img:
            start_time = time.time()
            img.write(sess.run(tf.image.encode_jpeg(generated)))
            end_time = time.time()
        print("Elapsed time: {:.3f}s".format(end_time - start_time))
        print("Done. Please check {}.".format(generated_file))
'''卷积层'''
node name: conv1/conv/truncated_normal/shape----->node operation: Const
node name: conv1/conv/truncated_normal/mean----->node operation: Const
node name: conv1/conv/truncated_normal/stddev----->node operation: Const
...
'''残差层'''
node name: res1/residual/conv/MirrorPad----->node operation: MirrorPad
node name: res1/residual/conv/conv----->node operation: Conv2D
node name: res1/residual/Relu----->node operation: Relu
node name: res1/residual/Equal----->node operation: Equal
'''图像恢复'''
node name: deconv1/conv_transpose/Shape----->node operation: Const
...
node name: deconv1/conv_transpose/conv/weight----->node operation: VariableV2
node name: deconv1/conv_transpose/conv/weight/Assign----->node operation: Assign
node name: deconv1/conv_transpose/conv/weight/read----->node operation: Identity
'''保存模型'''
node name: save/Const----->node operation: Const
...
node name: save/Assign_15----->node operation: Assign
node name: save/restore_all----->node operation: NoOp
node name: init----->node operation: NoOp
node name: init_1----->node operation: NoOp
slim-vgg 网络用于提取图像的内容和风格。
(1) 载入模型前先确认模型类型,即该模型中保存的是参数还是结构;若模型只含有参数,则载入模型前需要先建立网络结构;
(2) 图像风格迁移训练的网络是新建的神经网络,不是 slim 的 vgg 网络;风格网络结构中的最后三层(deconv1、deconv2、deconv3)将深层特征图逐步上采样,并把通道数由 128 依次降为 64、32、3,以获得正常的三通道(RGB)图像;
(3) 使用训练好的模型转换图片时,直接将图像数据传入神经网络即可,并未使用 sess.run(variable, feed_dict={x: x_data}) 的方式喂入数据。
[参考文献]
[1]https://blog.csdn.net/Xin_101/article/details/88883977
[2]https://blog.csdn.net/Xin_101/article/details/87854371
[3]https://blog.csdn.net/Xin_101/article/details/88581250
[4]https://blog.csdn.net/Xin_101/article/details/84981890