基于CycleGAN的性别变换方法

GAN的简介

近年来,GAN(生成对抗式网络)成功地应用于图像生成、图像编辑和和表达学习等方面。最小化对抗损失使得生成的图像看起来真实。GAN的基本原理为:

  • 生成器G是生成图片的网络,接收一个随机的噪声z,生成图片G(z)。其目标是尽量生成真实的图片去欺骗判别网络D。
  • 判别器D是判别一张图片是否为真实。输入一张图片x,输出D(x)为x为真实图片的概率。其目的是尽量把生成器生成的图片和真实的图片区别出来。


    基于CycleGAN的性别变换方法_第1张图片
    GAN网络

在理想情况下,生成器可以生成足以以假乱真的图片。而判别器难以辨别生成器生成的图片是否为真。

GAN的损失函数为:


GAN的损失函数

CycleGAN原理

图像与图像之间的变换

在传统的CNN方法中,图像与图像之间的变换是通过CNN来学习转移参数。
而本文的cycleGAN算法可以直接从一个图像生成另一个图像来实现图像之间的变换。

CycleGAN

目的:学习域X与域Y之间的映射关系。在CycleGAN模型中包括两个映射:X->Y, Y->X。如下图所示。


基于CycleGAN的性别变换方法_第2张图片
CycleGAN网络

在该网络中,存在两个域之间分别转换的生成器,以及每个生成器对应的判别器。目标函数中包括两项:

  • 对抗损失:使用控制生成的图像为目标域的图像。
对抗损失
  • cycle loss:为了防止两个生成器之间是相互矛盾的。


    cycle损失

在本项目中用来实现男女性别两个域之间的转换。

代码解析

### generator
conv(7, 7, 32)
conv(3, 3, 64)
conv(3, 3, 128)
res_block * 6 
deconv(3, 3, 64)
deconv(3, 3, 32)
conv(7, 7, 3)

### discriminator
conv(3, 3, 64)
conv(3, 3, 128)
conv(3, 3, 256)
conv(3, 3, 512)
conv(4, 4, 512)
### resnet_block
def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
    with tf.variable_scope(name):
        out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
        return tf.nn.relu(out_res + inputres)
### generator
def build_generator_resnet_9blocks_tf(inputgen, name="generator", skip=False):
    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "REFLECT"

        pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ ks, ks], [0, 0]], padding)
        o_c1 = layers.general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02, name="c1")
        o_c2 = layers.general_conv2d(o_c1, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
        o_c3 = layers.general_conv2d(o_c2, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3")

        o_r1 = build_resnet_block(o_c3, ngf * 4, "r1", padding)
        o_r2 = build_resnet_block(o_r1, ngf * 4, "r2", padding)
        o_r3 = build_resnet_block(o_r2, ngf * 4, "r3", padding)
        o_r4 = build_resnet_block(o_r3, ngf * 4, "r4", padding)
        o_r5 = build_resnet_block(o_r4, ngf * 4, "r5", padding)
        o_r6 = build_resnet_block(o_r5, ngf * 4, "r6", padding)
        o_r7 = build_resnet_block(o_r6, ngf * 4, "r7", padding)
        o_r8 = build_resnet_block(o_r7, ngf * 4, "r8", padding)
        o_r9 = build_resnet_block(o_r8, ngf * 4, "r9", padding)

        o_c4 = layers.general_deconv2d(o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c4")
        o_c5 = layers.general_deconv2d(o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,"SAME", "c5")
        o_c6 = layers.general_conv2d(o_c5, IMG_CHANNELS, f, f, 1, 1, 0.02, "SAME", "c6",do_norm=False, do_relu=False)

        if skip is True:
            out_gen = tf.nn.tanh(inputgen + o_c6, "t1")
        else:
            out_gen = tf.nn.tanh(o_c6, "t1")

        return out_gen
### discriminator
def discriminator_tf(inputdisc, name="discriminator"):
    with tf.variable_scope(name):
        f = 4
        o_c1 = layers.general_conv2d(inputdisc, ndf, f, f, 2, 2,0.02, "SAME", "c1", do_norm=False, relufactor=0.2)
        o_c2 = layers.general_conv2d(o_c1, ndf * 2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2)
        o_c3 = layers.general_conv2d(o_c2, ndf * 4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)
        o_c4 = layers.general_conv2d(o_c3, ndf * 8, f, f, 1, 1,0.02, "SAME", "c4", relufactor=0.2)
        o_c5 = layers.general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False
        )
        return o_c5
### layers.py
import tensorflow as tf
def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    with tf.variable_scope(name):
        if alt_relu_impl:
            f1 = 0.5 * (1 + leak)
            f2 = 0.5 * (1 - leak)
            return f1 * x + f2 * abs(x)
        else:
            return tf.maximum(x, leak * x)
def instance_norm(x):
    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        scale = tf.get_variable('scale', [x.get_shape()[-1]], initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02
        ))
        offset = tf.get_variable('offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0)
        )
        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset
        return out
def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                   padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                   relufactor=0):
    with tf.variable_scope(name):

        conv = tf.contrib.layers.conv2d( inputconv, o_d, f_w, s_w, padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev
            ), biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            conv = instance_norm(conv)
        if do_relu:
            if(relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")
        return conv

def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                     stddev=0.02, padding="VALID", name="deconv2d",
                     do_norm=True, do_relu=True, relufactor=0):
    with tf.variable_scope(name):

        conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None,weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            conv = instance_norm(conv)
            # conv = tf.contrib.layers.batch_norm(conv, decay=0.9,
            # updates_collections=None, epsilon=1e-5, scale=True,
            # scope="batch_norm")
        if do_relu:
            if(relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")
        return conv

测试的结果:


测试的结果

你可能感兴趣的:(基于CycleGAN的性别变换方法)