Style Transfer

Introduction

An implementation of fast neural style transfer in TensorFlow, following the paper [1].

Principle

The idea is to optimize an image against a content image and a style image so that both the content loss and the style loss are as small as possible. The fast style transfer network is structured as follows:


[Figure: architecture of the fast style transfer network]

Here the style image is fixed while the content image varies; that is, the network converts an arbitrary input image into a stylized image in the chosen style.

  • Transformation network: its parameters are trained; it converts a content image into a stylized image.
  • Loss network: computes the style loss between the stylized image and the style image, and the content loss between the stylized image and the original content image (both losses are sketched right after this list).
    After training, the stylized images produced by the transformation network resemble the input content image in content and the chosen style image in style.
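
As a rough sketch of the two losses, assuming the feature tensors are taken from the loss network defined below: the content loss is a mean squared error between feature maps, and the style loss compares Gram matrices of feature maps, as in [1]. The exact layer choices and normalization here are illustrative, not from the original code.

import tensorflow as tf

def gram_matrix(feats):
    # Gram matrix of a feature map: channel-to-channel correlations,
    # normalized by the size of the feature map.
    shape = tf.shape(feats)
    h, w, c = shape[1], shape[2], shape[3]
    feats = tf.reshape(feats, [-1, h * w, c])
    size = tf.cast(h * w * c, tf.float32)
    return tf.matmul(feats, feats, transpose_a=True) / size

def content_loss(transfer_feat, content_feat):
    # Mean squared error between the feature maps of one layer.
    return tf.reduce_mean(tf.square(transfer_feat - content_feat))

def style_loss(transfer_feats, style_feats):
    # Sum of Gram-matrix MSEs over the chosen style layers.
    loss = 0.
    for t, s in zip(transfer_feats, style_feats):
        loss += tf.reduce_mean(tf.square(gram_matrix(t) - gram_matrix(s)))
    return loss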

Code

1. Use VGG16 as the loss network:

import numpy as np
import tensorflow as tf

relu = tf.nn.relu  # shorthand used throughout

def conv_(inputs, w, b):
    # Convolution with fixed pretrained VGG weights, stride 1, SAME padding.
    return tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME") + b

def max_pooling(inputs):
    # 2x2 max pooling with stride 2, halving the spatial resolution.
    return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")

def vggnet(inputs, vgg_path='/home/luodan/project/fast_neural_style/Conditional-Instance-Norm-for-n-Style-Transfer/vgg_para/'):
    # Convert RGB to BGR and subtract the ImageNet channel means, as VGG expects.
    inputs = tf.reverse(inputs, [-1]) - np.array([103.939, 116.779, 123.68])
    # Load the pretrained VGG16 weights; each entry holds [kernel, bias].
    para = np.load(vgg_path + "vgg16.npy", encoding="latin1", allow_pickle=True).item()
    F = {}  # feature maps collected for the content and style losses
    inputs = relu(conv_(inputs, para["conv1_1"][0], para["conv1_1"][1]))
    inputs = relu(conv_(inputs, para["conv1_2"][0], para["conv1_2"][1]))
    F["conv1_2"] = inputs
    inputs = max_pooling(inputs)
    inputs = relu(conv_(inputs, para["conv2_1"][0], para["conv2_1"][1]))
    inputs = relu(conv_(inputs, para["conv2_2"][0], para["conv2_2"][1]))
    F["conv2_2"] = inputs
    inputs = max_pooling(inputs)
    inputs = relu(conv_(inputs, para["conv3_1"][0], para["conv3_1"][1]))
    inputs = relu(conv_(inputs, para["conv3_2"][0], para["conv3_2"][1]))
    inputs = relu(conv_(inputs, para["conv3_3"][0], para["conv3_3"][1]))
    F["conv3_3"] = inputs
    inputs = max_pooling(inputs)
    inputs = relu(conv_(inputs, para["conv4_1"][0], para["conv4_1"][1]))
    inputs = relu(conv_(inputs, para["conv4_2"][0], para["conv4_2"][1]))
    inputs = relu(conv_(inputs, para["conv4_3"][0], para["conv4_3"][1]))
    F["conv4_3"] = inputs
    return F
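
A minimal usage sketch follows; the placeholder shape and the choice of layers are assumptions for illustration, not fixed by the code above.

# Hypothetical usage: feed 0-255 RGB images and read out the feature dict.
images = tf.placeholder(tf.float32, [None, 256, 256, 3])
feats = vggnet(images)
content_feat = feats["conv4_3"]  # assumed content layer
style_layers = ["conv1_2", "conv2_2", "conv3_3", "conv4_3"]  # assumed style layers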

2. The transformation network:

def transfer(image):
    # Downsampling: one stride-1 conv followed by two stride-2 convs.
    conv1 = _conv_layer(image, 32, 9, 1)
    conv2 = _conv_layer(conv1, 64, 3, 2)
    conv3 = _conv_layer(conv2, 128, 3, 2)
    # Five residual blocks at the lowest resolution.
    resid1 = _residual_block(conv3, 3)
    resid2 = _residual_block(resid1, 3)
    resid3 = _residual_block(resid2, 3)
    resid4 = _residual_block(resid3, 3)
    resid5 = _residual_block(resid4, 3)
    # Upsampling: nearest-neighbor resize + conv is used instead of transposed
    # convolution, which helps avoid checkerboard artifacts.
#     conv_t1 = _conv_tranpose_layer(resid5, 64, 3, 2)
#     conv_t2 = _conv_tranpose_layer(conv_t1, 32, 3, 2)
    conv_up1 = upsampling(resid5, 64, 3)
    conv_up2 = upsampling(conv_up1, 32, 3)
    conv_up3 = _conv_layer(conv_up2, 3, 9, 1, relu=False)
    # Sigmoid squashes the output to [0, 1]; scale back to 0-255 pixel values.
    preds = tf.nn.sigmoid(conv_up3) * 255.
    return preds

def _conv_layer(net, num_filters, filter_size, strides, relu=True):
    # Convolution + instance normalization, optionally followed by ReLU.
    weights_init = _conv_init_vars(net, num_filters, filter_size)
    strides_shape = [1, strides, strides, 1]
    net = tf.nn.conv2d(net, weights_init, strides_shape, padding='SAME')
    net = _instance_norm(net)
    if relu:
        net = tf.nn.relu(net)
    return net

def upsampling(net, num_filters, filter_size):
    # Double the spatial resolution with a nearest-neighbor resize, then conv.
    net = tf.image.resize_nearest_neighbor(net, [tf.shape(net)[1] * 2, tf.shape(net)[2] * 2])
    weights_init = _conv_init_vars(net, num_filters, filter_size)
    net = tf.nn.conv2d(net, weights_init, [1, 1, 1, 1], padding='SAME')
    return _instance_norm(net)

def _residual_block(net, filter_size=3):
    # Two conv layers with a skip connection; the second conv has no ReLU.
    tmp = _conv_layer(net, 128, filter_size, 1)
    return net + _conv_layer(tmp, 128, filter_size, 1, relu=False)

def _instance_norm(net):
    # Instance normalization: normalize each sample over its spatial
    # dimensions, per channel, with a learnable scale and shift.
    channels = net.get_shape()[3].value
    mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
    shift = tf.Variable(tf.zeros([channels]))
    scale = tf.Variable(tf.ones([channels]))
    epsilon = 1e-3
    normalized = (net - mu) / (sigma_sq + epsilon) ** (.5)
    return scale * normalized + shift

def _conv_init_vars(net, out_channels, filter_size, transpose=False):
    # Create a weight variable for a (possibly transposed) conv layer,
    # initialized from a truncated normal distribution.
    _, rows, cols, in_channels = [i.value for i in net.get_shape()]
    if not transpose:
        weights_shape = [filter_size, filter_size, in_channels, out_channels]
    else:
        weights_shape = [filter_size, filter_size, out_channels, in_channels]
    weights_init = tf.Variable(tf.truncated_normal(weights_shape, stddev=0.02, seed=1), dtype=tf.float32)
    return weights_init
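
To tie everything together, here is a rough training sketch; the loss weights, layer choices, and learning rate are illustrative assumptions, not values from the original code.

# Hypothetical training setup combining the transformation and loss networks.
content_images = tf.placeholder(tf.float32, [None, 256, 256, 3])
style_image = tf.placeholder(tf.float32, [1, 256, 256, 3])

transferred = transfer(content_images)
t_feats = vggnet(transferred)
c_feats = vggnet(content_images)
s_feats = vggnet(style_image)

style_layers = ["conv1_2", "conv2_2", "conv3_3", "conv4_3"]  # assumed
loss = 1.0 * content_loss(t_feats["conv4_3"], c_feats["conv4_3"]) \
     + 5.0 * style_loss([t_feats[l] for l in style_layers],
                        [s_feats[l] for l in style_layers])

# Only the transformation network has trainable variables; the VGG weights
# enter the graph as numpy constants, so they stay fixed during training.
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)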

References

[1] Justin Johnson, Alexandre Alahi, Li Fei-Fei. Perceptual Losses for Real-Time Style Transfer and Super-Resolution. ECCV 2016.
