Beyond translation between languages, we often encounter various image "translation" tasks in real life: given an input image, generate a corresponding target image. Common examples include image style transfer, super-resolution, colorization, and mosaic removal. In practice, an even more common scenario is photo retouching and beautification, so this article introduces another image translation task: AI retouching. Given an input image, the machine automatically retouches it to produce a more polished result.
This article uses a classic GAN model, pix2pix. The network takes a fully supervised approach: it is trained on exactly paired input and output images, and the trained model maps an input image to the target image for a given task. Among fully supervised approaches to image translation, it is currently one of the best models in both quality and generality. Before describing the architecture, it is worth looking at some of the interesting experiments the authors ran with this network:
The results are shown in the figure below:
pix2pix is a GAN, specifically one built on the conditional GAN (cGAN) structure, so it still consists of a generator and a discriminator. The generator is a U-Net, which resembles an encoder-decoder: 8 convolutional layers form the encoder and 8 transposed-convolution (deconvolution) layers form the decoder. Unlike a plain encoder-decoder, it adds "skip connections": the input to each decoder layer is the previous layer's output concatenated with the output of the symmetric encoder layer. This lets the decoder keep re-reading the encoder's features, so the generated image preserves as much information from the input as possible. The structure is shown in the figure below:
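In code, the skip-connection wiring amounts to the following minimal sketch (TF 1.x style; the function name and shapes here are illustrative, and the full generator appears in the code later):

import tensorflow as tf

def decoder_step(prev_decoder_out, mirrored_encoder_out, filters):
    # decoder input = previous decoder output + the symmetric encoder output,
    # concatenated along the channel axis
    h = tf.concat([prev_decoder_out, mirrored_encoder_out], axis=3)
    h = tf.nn.relu(h)
    # stride-2 transposed convolution doubles the spatial resolution
    return tf.layers.conv2d_transpose(h, filters=filters, kernel_size=4,
                                      strides=2, padding='same')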
For the discriminator, pix2pix uses a small convolutional network of five layers. The idea is the same as a conventional discriminator, with two notable differences: first, it is conditional, so the input image is concatenated with the real or generated image along the channel axis before being judged; second, rather than a single real/fake score, it outputs a 30×30 grid of scores, each classifying one patch of the image as real or fake (the "PatchGAN" design).
Its concrete structure is shown in the figure below:
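Each of those 30×30 outputs judges one patch of the input rather than the whole image. For the five 4×4 convolutions used below, with strides 2, 2, 2, 1, 1, the receptive field of one output unit works out to 70×70; a quick sketch of the arithmetic:

# Receptive field of one PatchGAN output unit: walk backwards through the conv
# stack using rf_in = stride * rf_out + (kernel - stride).
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]  # (kernel, stride) per conv
rf = 1
for kernel, stride in reversed(layers):
    rf = stride * rf + (kernel - stride)
print(rf)  # 70 -> each output score judges a 70x70 input patch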
In addition to the standard GAN loss, pix2pix introduces an L1 loss. Let $x$ denote the input image, $y$ the real (target) image, $G$ the generator, and $D$ the discriminator. The standard conditional GAN loss is:

$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x}[\log(1 - D(x, G(x)))]$$

The L1 loss imposes a pixel-wise penalty on $G$:

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}[\lVert y - G(x) \rVert_1]$$

The final objective is therefore:

$$G^* = \arg\min_G \max_D \mathcal{L}_{cGAN}(G, D) + \lambda\, \mathcal{L}_{L1}(G)$$

where $\lambda$ weights the L1 term (the paper uses $\lambda = 100$). This way, the standard GAN loss is responsible for capturing the image's high-frequency detail, while the L1 loss captures the low-frequency structure, so the generated result is both realistic and sharp.
This article applies pix2pix to AI retouching, implemented in TensorFlow. First, each input image and its ground-truth (retouched output) counterpart are resized to 256×256 and concatenated side by side, in the following form:
Here the left half is the original photo before retouching and the right half is the manually retouched result. A total of 1,700 such image pairs were collected as the training set. The main code of the model is as follows:
import tensorflow as tf
import numpy as np
from PIL import Image
from data_loader import get_batch_data
import os
class pix2pix(object):
def __init__(self, sess, batch_size, L1_lambda):
"""
:param sess: tf.Session
:param batch_size: batch_size. [int]
:param L1_lambda: L1_loss lambda. [int]
"""
        self.sess = sess
        self.batch_size = batch_size
        # conv kernels ~ N(0, 0.02); batch-norm gamma ~ N(1, 0.02)
        self.k_initializer = tf.random_normal_initializer(0, 0.02)
        self.g_initializer = tf.random_normal_initializer(1, 0.02)
        self.L1_lambda = L1_lambda
        self.build_model()
    def build_model(self):
"""
        Build the model graph: placeholders, generator, discriminator, losses and summaries.
:return:
"""
# init variable
self.x_ = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3], name='x')
self.y_ = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3], name='y')
# generator
self.g = self.generator(self.x_)
# discriminator
self.d_real = self.discriminator(self.x_, self.y_)
self.d_fake = self.discriminator(self.x_, self.g, reuse=True)
# loss
self.loss_g, self.loss_d = self.loss(self.d_real, self.d_fake, self.y_, self.g)
# summary
tf.summary.scalar("loss_g", self.loss_g)
tf.summary.scalar("loss_d", self.loss_d)
self.merged = tf.summary.merge_all()
# vars
self.vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
self.vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
# saver
self.saver = tf.train.Saver()
def discriminator(self, x, y, reuse=None):
"""
        Discriminator (a conditional PatchGAN)
        :param x: input image. [tensor]
        :param y: real or generated image. [tensor]
        :param reuse: whether to reuse variables. [boolean]
        :return: a 30x30x1 grid of per-patch logits
"""
with tf.variable_scope('discriminator', reuse=reuse):
x = tf.concat([x, y], axis=3)
h0 = self.lrelu(self.d_conv(x, 64, 2)) # 128 128 64
h0 = self.d_conv(h0, 128, 2)
h0 = self.lrelu(self.batch_norm(h0)) # 64 64 128
h0 = self.d_conv(h0, 256, 2)
h0 = self.lrelu(self.batch_norm(h0)) # 32 32 256
h0 = self.d_conv(h0, 512, 1)
h0 = self.lrelu(self.batch_norm(h0)) # 31 31 512
            h0 = self.d_conv(h0, 1, 1)  # 30 30 1, per-patch logits
            # return raw logits: the loss applies the sigmoid itself via
            # tf.nn.sigmoid_cross_entropy_with_logits, so adding tf.nn.sigmoid
            # here would squash the scores twice
            return h0
def generator(self, x):
"""
        Generator (U-Net)
        :param x: input image. [tensor]
        :return: h0, the generated image. [tensor]
"""
with tf.variable_scope('generator', reuse=None):
layers = []
h0 = self.g_conv(x, 64)
layers.append(h0)
            for filters in [128, 256, 512, 512, 512, 512, 512]:  # 7 more down-sampling convs
h0 = self.lrelu(layers[-1])
h0 = self.g_conv(h0, filters)
h0 = self.batch_norm(h0)
layers.append(h0)
encode_layers_num = len(layers) # 8
            for i, filters in enumerate([512, 512, 512, 512, 256, 128, 64]):  # 7 up-sampling deconvs
skip_layer = encode_layers_num - i - 1
if i == 0:
inputs = layers[-1]
else:
inputs = tf.concat([layers[-1], layers[skip_layer]], axis=3)
h0 = tf.nn.relu(inputs)
h0 = self.g_deconv(h0, filters)
h0 = self.batch_norm(h0)
                if i < 3:
                    h0 = tf.nn.dropout(h0, keep_prob=0.5)  # dropout on the first 3 decoder layers, as in the paper
layers.append(h0)
inputs = tf.concat([layers[-1], layers[0]], axis=3)
h0 = tf.nn.relu(inputs)
h0 = self.g_deconv(h0, 3)
h0 = tf.nn.tanh(h0, name='g')
return h0
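    # For reference, the spatial resolution through this U-Net with 256x256 inputs
    # (channel counts follow the filter lists above):
    #   encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1   (bottleneck, 512 ch)
    #   decoder: 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256   (output, 3 ch)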
def loss(self, d_real, d_fake, y, g):
"""
        Define the loss functions
        :param d_real: discriminator logits for real images. [tensor]
        :param d_fake: discriminator logits for generated images. [tensor]
        :param y: target image. [tensor]
        :param g: generated image. [tensor]
        :return: loss_g, loss_d, the generator loss and the discriminator loss
"""
loss_d_real = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_real, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_fake, tf.zeros_like(d_fake)))
loss_d = loss_d_real + loss_d_fake
loss_g_gan = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_fake, tf.ones_like(d_fake)))
loss_g_l1 = tf.reduce_mean(tf.abs(y - g))
loss_g = loss_g_gan + loss_g_l1 * self.L1_lambda
return loss_g, loss_d
def lrelu(self, x, leak=0.2):
"""
        Leaky ReLU activation
        :param x: input tensor
        :param leak: slope for negative inputs
:return:
"""
return tf.maximum(x, leak * x)
def d_conv(self, inputs, filters, strides):
"""
        Discriminator convolution layer: 4x4 kernel with explicit 1-pixel zero padding
        :param inputs: input. [tensor]
        :param filters: number of output channels. [int]
        :param strides: convolution stride. [int]
:return:
"""
padded = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')
return tf.layers.conv2d(padded,
kernel_size=4,
filters=filters,
strides=strides,
padding='valid',
kernel_initializer=self.k_initializer)
def g_conv(self, inputs, filters):
"""
        Generator (encoder) convolution layer: 4x4 kernel, stride 2, halves the spatial size
        :param inputs: input. [tensor]
        :param filters: number of output channels. [int]
:return:
"""
return tf.layers.conv2d(inputs,
kernel_size=4,
filters=filters,
strides=2,
padding='same',
kernel_initializer=self.k_initializer)
def g_deconv(self, inputs, filters):
"""
        Generator (decoder) transposed-convolution layer: 4x4 kernel, stride 2, doubles the spatial size
        :param inputs: input. [tensor]
        :param filters: number of output channels. [int]
:return:
"""
return tf.layers.conv2d_transpose(inputs,
kernel_size=4,
filters=filters,
strides=2,
padding='same',
kernel_initializer=self.k_initializer)
def batch_norm(self, inputs):
"""
        Batch normalization (always run in training mode here)
        :param inputs: input. [tensor]
:return:
"""
return tf.layers.batch_normalization(inputs,
axis=3,
epsilon=1e-5,
momentum=0.1,
training=True,
gamma_initializer=self.g_initializer)
def sigmoid_cross_entropy_with_logits(self, x, y):
"""
        Sigmoid cross-entropy between logits and labels
        :param x: logits
        :param y: labels
:return:
"""
return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,
labels=y)
def train(self, images, epoch, batch_size):
"""
        Training loop
        :param images: list of training image paths. [list]
        :param epoch: number of epochs. [int]
        :param batch_size: batch_size. [int]
:return:
"""
        # optimizer: run the batch-norm update ops together with each training step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
optim_d = tf.train.AdamOptimizer(learning_rate=0.0002,
beta1=0.5
).minimize(self.loss_d, var_list=self.vars_d)
optim_g = tf.train.AdamOptimizer(learning_rate=0.0002,
beta1=0.5
).minimize(self.loss_g, var_list=self.vars_g)
# init variables
init_op = tf.global_variables_initializer()
self.sess.run(init_op)
self.writer = tf.summary.FileWriter("./log", self.sess.graph)
# training
        for i in range(epoch):
            print("Epoch:%d/%d:" % ((i + 1), epoch))
            # split the image path list into batches
            batch_num = int(np.ceil(len(images) / batch_size))
            batch_list = np.array_split(images, batch_num)
            # one discriminator step, then one generator step, per batch
            for j in range(len(batch_list)):
batch_x, batch_y = get_batch_data(batch_list[j])
_, loss_d = self.sess.run([optim_d, self.loss_d],
feed_dict={self.x_: batch_x, self.y_: batch_y})
_, loss_g = self.sess.run([optim_g, self.loss_g],
feed_dict={self.x_: batch_x, self.y_: batch_y})
print("%d/%d -loss_d:%.4f -loss_g:%.4f" % ((j + 1), len(batch_list), loss_d, loss_g))
            # write the loss summaries to TensorBoard
summary = self.sess.run(self.merged,
feed_dict={self.x_: batch_x, self.y_: batch_y})
self.writer.add_summary(summary, global_step=i)
            # save a checkpoint every 10 epochs
if (i + 1) % 10 == 0:
self.saver.save(self.sess, './checkpoint/epoch_%d.ckpt' % (i + 1))
            # run a quick test after every epoch
            if (i + 1) % 1 == 0:
                # test on the last training batch
                train_save_path = os.path.join('./result/train',
                                               os.path.splitext(os.path.basename(images[-1]))[0]
                                               + '_' + str(i + 1) + '.jpg')
train_g = self.sess.run(self.g,
feed_dict={self.x_: batch_x}
)
                train_g = 255 * (np.array(train_g[0] + 1) / 2)  # map tanh output [-1, 1] back to [0, 255]
im = Image.fromarray(np.uint8(train_g))
im.save(train_save_path)
                # test on a fixed validation image: save input | target | output side by side
img = np.zeros((256, 256 * 3, 3))
val_img_path = np.array(['./data/val/color/10901.jpg'])
batch_x, batch_y = get_batch_data(val_img_path)
val_g = self.sess.run(self.g, feed_dict={self.x_: batch_x})
img[:, :256, :] = 255 * (np.array(batch_x + 1) / 2)
img[:, 256:256 * 2, :] = 255 * (np.array(batch_y + 1) / 2)
img[:, 256 * 2:, :] = 255 * (np.array(val_g[0] + 1) / 2)
img = Image.fromarray(np.uint8(img))
img.save('./result/val/10901_%d.jpg' % (i + 1))
def save_img(self, g, data, save_path):
"""
        Save a side-by-side comparison image
        :param g: generated images. [array]
        :param data: test data: input image and, if available, the target image. [list]
        :param save_path: save path. [str]
:return:
"""
if len(data) == 1:
img = np.zeros((256, 256 * 2, 3))
            img[:, :256, :] = 255 * (np.array(data[0] + 1) / 2)
img[:, 256:, :] = 255 * (np.array(g[0] + 1) / 2)
else:
img = np.zeros((256, 256 * 3, 3))
img[:, :256, :] = 255 * (np.array(data[0] + 1) / 2)
img[:, 256:256 * 2, :] = 255 * (np.array(data[1] + 1) / 2)
img[:, 256 * 2:, :] = 255 * (np.array(g[0] + 1) / 2)
im = Image.fromarray(np.uint8(img))
im.save(os.path.join('./result/test', os.path.basename(save_path)))
def test(self, images, batch_size=1, save_path=None, mode=None):
"""
        Test function
        :param images: list of test image paths. [list]
        :param batch_size: batch_size. [int]
        :param save_path: save path. [str]
        :param mode: 'orig' to run on unpaired images that have no ground truth
:return:
"""
# init variables
init_op = tf.global_variables_initializer()
self.sess.run(init_op)
# load model
self.saver.restore(self.sess,
tf.train.latest_checkpoint('./checkpoint')
)
# test
if mode != 'orig':
for j in range(len(images)):
batch_x, batch_y = get_batch_data(np.array([images[j]]))
g = self.sess.run(self.g, feed_dict={self.x_: batch_x})
                if save_path is None:
self.save_img(g,
data=[batch_x[0], batch_y[0]],
save_path=images[j]
)
else:
self.save_img(g,
data=[batch_x[0], batch_y[0]],
save_path=save_path
)
else:
for j in range(len(images)):
batch_x = get_batch_data(np.array([images[j]]), mode=mode)
g = self.sess.run(self.g, feed_dict={self.x_: batch_x})
batch_x = 255 * (np.array(batch_x[0] + 1) / 2)
g = 255 * (np.array(g[0] + 1) / 2)
img = np.hstack((batch_x, g))
im = Image.fromarray(np.uint8(img))
im.save(os.path.join('./result/test', os.path.basename(images[j])))
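One piece not shown above is data_loader.get_batch_data. Here is a minimal sketch of what it plausibly does, assuming the 256×512 side-by-side pairs described earlier (left half input, right half target, both scaled to [-1, 1] to match the generator's tanh output); this is an assumption, not the original implementation:

# data_loader.py -- hypothetical sketch, not the original implementation
import numpy as np
from PIL import Image

def get_batch_data(paths, mode=None):
    xs, ys = [], []
    for path in paths:
        if mode == 'orig':
            # unpaired photo with no ground truth: resize to the model's input size
            im = Image.open(path).resize((256, 256))
            xs.append(np.array(im, dtype=np.float32) / 127.5 - 1)
        else:
            # 256x512 pair: left half is the original, right half the retouched target
            img = np.array(Image.open(path), dtype=np.float32) / 127.5 - 1
            xs.append(img[:, :256, :])
            ys.append(img[:, 256:, :])
    if mode == 'orig':
        return np.array(xs)
    return np.array(xs), np.array(ys)

Training and testing can then be driven by a short script along these lines (paths and hyperparameters here are illustrative; L1_lambda=100 follows the paper's choice of λ):

import glob
import tensorflow as tf

with tf.Session() as sess:
    model = pix2pix(sess, batch_size=4, L1_lambda=100)
    train_images = glob.glob('./data/train/*.jpg')
    model.train(train_images, epoch=40, batch_size=4)
    model.test(glob.glob('./data/test/*.jpg'))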
After 40 epochs of training, the discriminator and generator losses reached an equilibrium, so training was stopped, as shown below:
Testing the model trained for 40 epochs on the test set gives the following final results:
From left to right: the original photo, the manual retouch, and the AI retouch. The AI result produces more vivid colors, and its retouching looks a bit more natural than the manual one. The trained model was also tested on high-resolution images of arbitrary size, with the following results:
On the left are two landscape photos downloaded directly from Baidu; on the right are the results after retouching with the trained model. Although both originals had already been edited, the AI retouch still noticeably improves brightness and color contrast, so the model generalizes quite well.
Finally, a few words on the model's drawbacks. Although pix2pix is very general, whether it converges depends heavily on data quality: with poor data, the trained model performs poorly. I initially skipped data cleaning, and the resulting outputs were blurry. pix2pix also strictly requires paired data, which makes its data requirements demanding. Readers interested in relaxing this constraint can look at unsupervised models such as WESPE. The original paper and the authors' PyTorch implementation are here: