本文目录:
- Introduction
- Related work
- Methods
- Gram 矩阵
- Batch Normalization
Introduction
不久前,一个名叫Prisma的APP在微博和朋友圈火了起来。Prisma是个能够将图像风格转换为艺术风格的APP,它能够实现如下转换:
除了引起大众的好奇心外,业内人士也纷纷猜测Prisma是如何做到实现快速的图像风格转换。此前,在Gatys的论文
在文中我将讲解Prisma是如何实现实时风格转换的。本文内容基于Fei Fei Li团队的
系列文章目录如下:
- 梵高眼中的世界(一)实时图像风格转换简介
- 梵高眼中的世界(二)基于perceptual损失的网络
- 梵高眼中的世界(三)实现与改进
Related work
在进行图像风格转换时,我们需要一张风格图像style image和一张内容图像content image。我们构造一个网络衡量生成图像与style image以及content image的loss,再通过训练减小loss得到最终图像。
在Gatys的方法中,他使用了如下图所示的方法:
上图最左边是风格图像,梵高的《星夜》;最右边是内容图像。
算法步骤如下:
生成了一张白噪声图像作为初始图像。
将风格图像,内容图像,初始图像分别通过一个预训练的VGG-19网络,得到某些层的输出。这里的“某些层”是经过实验得出的,是使得输出图像最佳的层数。
-
计算内容损失函数:
其中Pl_ij是原始图像在第l层位置j与第i个filter卷积后的输出,Fl_ij是相应的生成图像的输出。
计算风格损失函数:
风格损失函数与图像有些不同,在这里我们不直接使用某些层卷积后的输出,而是计算输出的Gram矩阵,再用于上式风格损失的计算:
5.计算总损失
此时我们可以通过梯度下降算法对初始化的白噪声图像进行训练,得到最终的风格转换图像。
Gatys的算法缺点是一次只能训练出一张图。我们希望得到一个前馈的神经网络,对于每一张内容图像,只需要通过这个前馈神经网络,就能快速得到风格转换图像。
Methods
在这里只对Gram matrix以及Batch Normalization进行讲解,具体实现细节请阅读原文。
Gram matrix
Gram matrix 计算如下:
上式的意思为,G^l_i,j意味着第l层特征图i和j的内积。同理可表示为:
在论文中,作者用高维的特征图相关性来表示图像风格。上式矩阵的对角线表示每一个特征图自身的信息,其余元素表示了不同特征图之间的信息。
Gram matrix的tensorflow实现如下:
def gram_matrix(x):
'''
Args:
x: Tensor with shape [batch size, length, width, channels]
Return:
Tensor with shape [channels, channels]
'''
bs, l, w, c = x.get_shape()
size = l*w*c
x = tf.reshape(x, (bs, l*w, c))
x_t = tf.transpose(x, perm=[0,2,1])
return tf.matmul(x_t, x)/size
Batch Normalization
Batch Normalization 最早由Google在ICML2015的论文
其算法如下:
这个算法看上去有点复杂,但直观上很好理解:
对于一个mini-batch里面的值x_i,我们计算平均值 μ和方差σ。对于每一个x_i,我们对其进行z-score归一化,得到平均值为0,标准差为1的数据。式子中的ε是一个很小的偏差值,防止出现除以0的情况。实现中可以取ε=1e-3。在对数据进行归一化后,BN算法再进行“scale and shift”,将数据还原成原来的输入。
Batch Normalization是为了解决Internal Covariate Shift问题而提出。
Batch Normalization在Tensorflow下的实现:
from tensorflow.contrib.layers import batch_norm
def batch_norm_layer(x, is_training, scope):
bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=True,
reuse=None,
trainable=True,
scope=scope)
bn_test = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=False,
reuse=True,
trainable=True,
scope=scope)
bn = tf.cond(is_training, lambda: bn_train, lambda: bn_test)
return bn
注意其中is_training是一个placeholder。