TensorFlow 2.0: CycleGAN

CycleGAN

CycleGAN can be used for style transfer, and it tackles the unpaired-data problem in image-to-image translation. As the figure below shows, an ordinary horse can be turned into a zebra while the background stays unchanged. At its core it is two symmetric GANs forming a ring network: the two GANs share the two generators, and each carries its own discriminator, for a total of two generators and two discriminators.

[Figure 1: horse-to-zebra translation example]
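
To make the later code self-contained, here is a minimal setup sketch for the four networks and their optimizers. The tiny Sequential models are stand-ins of my own so the snippet runs (a real CycleGAN uses a ResNet- or U-Net-style generator and a 70x70 PatchGAN discriminator); the Adam settings (lr=2e-4, beta_1=0.5) follow the original CycleGAN paper.

import tensorflow as tf
from tensorflow import keras

def Generator():
    # Tiny stand-in generator so the sketch runs; a real CycleGAN generator
    # is a ResNet- or U-Net-style network. In/out: 256x256 RGB images.
    return keras.Sequential([
        keras.layers.Conv2D(32, 3, padding="same", activation="relu",
                            input_shape=(256, 256, 3)),
        keras.layers.Conv2D(3, 3, padding="same", activation="tanh"),
    ])

def Discriminator():
    # Tiny stand-in discriminator; a real one is a 70x70 PatchGAN. It outputs
    # raw scores with no sigmoid, which is what the LSGAN loss below expects.
    return keras.Sequential([
        keras.layers.Conv2D(32, 4, strides=2, padding="same", activation="relu",
                            input_shape=(256, 256, 3)),
        keras.layers.Conv2D(1, 4, padding="same"),
    ])

genA2B, genB2A = Generator(), Generator()        # translators A->B and B->A
discA, discB = Discriminator(), Discriminator()  # judges for domains A and B

# One Adam optimizer per network; lr=2e-4 and beta_1=0.5 as in the paper.
genA2B_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
genB2A_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
discA_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
discB_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
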
Computing the Losses
At heart this is still a GAN, so the key is understanding how its loss functions are computed.

import tensorflow as tf

# Discriminator loss
def discriminator_loss(disc_of_real_output, disc_of_gen_output, lsgan=True):
    if lsgan:  # Use least-squares (LSGAN) loss
        # Real images should score 1, generated images should score 0.
        real_loss = tf.reduce_mean(tf.square(disc_of_real_output - 1))
        generated_loss = tf.reduce_mean(tf.square(disc_of_gen_output))
        total_disc_loss = (real_loss + generated_loss) * 0.5  # the 0.5 slows D's learning relative to G
    else:  # Use vanilla GAN (sigmoid cross-entropy) loss
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(disc_of_real_output), logits=disc_of_real_output))
        generated_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(disc_of_gen_output), logits=disc_of_gen_output))
        total_disc_loss = real_loss + generated_loss

    return total_disc_loss
# Generator loss
def generator_loss(disc_of_gen_output, lsgan=True):
    if lsgan:  # Use least-squares (LSGAN) loss
        # The generator wants the discriminator to score its output as 1.
        gen_loss = tf.reduce_mean(tf.square(disc_of_gen_output - 1))
    else:  # Use vanilla GAN (sigmoid cross-entropy) loss
        gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(disc_of_gen_output), logits=disc_of_gen_output))
        # l1_loss = tf.reduce_mean(tf.abs(target - gen_output))  # Look up pix2pix loss
    return gen_loss
# Cycle-consistency loss: reconstructions A->B->A and B->A->B should match
# the originals; cyc_lambda weights this term against the adversarial losses
def cycle_consistency_loss(data_A, data_B, reconstructed_data_A, reconstructed_data_B, cyc_lambda=10):
    loss = tf.reduce_mean(tf.abs(data_A - reconstructed_data_A) + tf.abs(data_B - reconstructed_data_B))
    return cyc_lambda * loss
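
As a quick sanity check, the three loss functions can be run on random tensors. The shapes below are illustrative assumptions (a PatchGAN discriminator outputs a score map rather than a single value):

import tensorflow as tf

real_scores = tf.random.normal([4, 30, 30, 1])  # discriminator scores for real images
fake_scores = tf.random.normal([4, 30, 30, 1])  # discriminator scores for generated images
imgs_A = tf.random.uniform([4, 256, 256, 3])
imgs_B = tf.random.uniform([4, 256, 256, 3])

print(discriminator_loss(real_scores, fake_scores, lsgan=True))  # scalar
print(generator_loss(fake_scores, lsgan=True))                   # scalar
# A perfect reconstruction gives zero cycle loss:
print(cycle_consistency_loss(imgs_A, imgs_B, imgs_A, imgs_B))    # tf.Tensor(0.0, ...)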

A few training details are also worth noting. A single training step looks like this:

        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as genB2A_tape, \
                tf.GradientTape() as discA_tape, tf.GradientTape() as discB_tape:
            try:
                trainA = next(train_datasetA)
                trainB = next(train_datasetB)
            except (StopIteration, tf.errors.OutOfRangeError):
                print("Error: ran out of data")
                break
            # The CycleGAN cycle
            # Generators: A -> fake B -> reconstructed A, and B -> fake A -> reconstructed B
            # Discriminators: discA judges real A against A generated from B;
            #                 discB judges real B against B generated from A
            genA2B_output = genA2B(trainA, training=True)
            genB2A_output = genB2A(trainB, training=True)
            discA_real_output = discA(trainA, training=True)
            discB_real_output = discB(trainB, training=True)
            discA_fake_output = discA(genB2A_output, training=True)
            discB_fake_output = discB(genA2B_output, training=True)
            reconstructedA = genB2A(genA2B_output, training=True)
            reconstructedB = genA2B(genB2A_output, training=True)
            # How the losses are computed
            # Discriminator A: distance of real A from 1 + distance of fake A (generated
            #   from B) from 0; the distance can be sigmoid cross-entropy or MSE
            # Generator A2B: distance of fake B from 1 + the cycle loss
            # Cycle loss = (|real A - reconstructed A| + |real B - reconstructed B|) * weight
            discA_loss = discriminator_loss(discA_real_output, discA_fake_output, lsgan=lsgan)
            discB_loss = discriminator_loss(discB_real_output, discB_fake_output, lsgan=lsgan)
            genA2B_loss = generator_loss(discB_fake_output, lsgan=lsgan) + \
                          cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB,
                                                 cyc_lambda=cyc_lambda)
            genB2A_loss = generator_loss(discA_fake_output, lsgan=lsgan) + \
                          cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB,
                                                 cyc_lambda=cyc_lambda)
        # Update the generator weights first, then the discriminator weights
        genA2B_gradients = genA2B_tape.gradient(genA2B_loss, genA2B.trainable_variables)
        genB2A_gradients = genB2A_tape.gradient(genB2A_loss, genB2A.trainable_variables)

        discA_gradients = discA_tape.gradient(discA_loss, discA.trainable_variables)
        discB_gradients = discB_tape.gradient(discB_loss, discB.trainable_variables)

        genA2B_optimizer.apply_gradients(zip(genA2B_gradients, genA2B.trainable_variables))
        genB2A_optimizer.apply_gradients(zip(genB2A_gradients, genB2A.trainable_variables))

        discA_optimizer.apply_gradients(zip(discA_gradients, discA.trainable_variables))
        discB_optimizer.apply_gradients(zip(discB_gradients, discB.trainable_variables))
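
For context, here is a sketch of the outer loop the step above sits in. The placeholder datasets and the epoch count are my own assumptions (real code would load and preprocess the two unpaired image folders); the point is that fresh iterators are drawn each epoch, and in eager mode next() raises StopIteration once an iterator is exhausted, hence the widened except clause in the step above.

import tensorflow as tf

# Placeholder datasets so this sketch runs on its own; in practice they are
# built from the two unpaired image folders (loading/augmentation omitted).
datasetA = tf.data.Dataset.from_tensor_slices(tf.random.uniform([8, 256, 256, 3])).batch(1)
datasetB = tf.data.Dataset.from_tensor_slices(tf.random.uniform([8, 256, 256, 3])).batch(1)

EPOCHS = 1       # illustrative; the CycleGAN paper trains for around 200 epochs
lsgan = True     # use the least-squares losses defined earlier
cyc_lambda = 10  # weight of the cycle-consistency term

for epoch in range(EPOCHS):
    # Fresh iterators every epoch; exhausting either one ends the epoch.
    train_datasetA = iter(datasetA)
    train_datasetB = iter(datasetB)
    while True:
        try:
            trainA = next(train_datasetA)
            trainB = next(train_datasetB)
        except (StopIteration, tf.errors.OutOfRangeError):
            break
        # ...the GradientTape training step shown above runs here on trainA/trainB...
    print("epoch {} finished".format(epoch + 1))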

Sample Results

[Figure 2: sample translation results]

Summary

Different datasets and real-world requirements still call for plenty of tuning tricks; I'll cover those in a future update.
