cyclegan原理与分析

discriminator:
判别器网络结构
经过5个卷积层(conv + bn + leaky_relu), 第5个卷积层将通道数降为1, 便于计算损失
(?,256,256,3) ->1 (?,128,128,64) ->2 (?,64,64,128) -> 3(?,32,32,256) ->4 (?,16,16,512) -> 5(?,16,16,1)
generator:
生成器网络结构:
(?,256,256,3) ->1 (?,256,256,64) ->2 (?,128,128,128) -> 3(?,64,64,256) ->4 (?,64,64,256) (是再接9个残差模块得到的)-> 5(?,128,128,128)(上采样)->6(?,256,256,64) ->7(?,256,256,3)
网络整体结构:

class CycleGAN:
    def __init__(self,
                 X_train_file='',
                 Y_train_file='',
                 batch_size=1,
                 image_size=256,
                 use_lsgan=True,
                 norm='instance',
                 lambda1=10,
                 lambda2=10,
                 learning_rate=2e-4,
                 beta1=0.5,
                 ngf=64
                 ):
        """Build the four CycleGAN sub-networks and the feed placeholders.

        Args:
          X_train_file: string, X tfrecords file for training
          Y_train_file: string, Y tfrecords file for training
          batch_size: integer, batch size
          image_size: integer, image size
          use_lsgan: boolean, use least-squares GAN losses instead of
            cross-entropy losses
          norm: 'instance' or 'batch'
          lambda1: integer, weight for forward cycle loss (X->Y->X)
          lambda2: integer, weight for backward cycle loss (Y->X->Y)
          learning_rate: float, initial learning rate for Adam
          beta1: float, momentum term of Adam
          ngf: number of generator filters in the first conv layer
        """
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.use_lsgan = use_lsgan
        # LSGAN regresses raw discriminator outputs, so a sigmoid output is
        # only needed for the cross-entropy variant.
        use_sigmoid = not use_lsgan
        self.batch_size = batch_size
        self.image_size = image_size
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.X_train_file = X_train_file
        self.Y_train_file = Y_train_file

        # Defaults to training mode; can be overridden via feed_dict at
        # session run time (e.g. for inference).
        self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training')

        # G: X -> Y generator; D_Y: discriminator for domain Y.
        self.G = Generator('G', self.is_training, ngf=ngf, norm=norm,
                           image_size=image_size)
        self.D_Y = Discriminator('D_Y',
                                 self.is_training, norm=norm,
                                 use_sigmoid=use_sigmoid)

        # F: Y -> X generator; D_X: discriminator for domain X.
        # FIX: the original omitted ngf here, so F silently fell back to
        # Generator's default width whenever a non-default ngf was passed;
        # forward it explicitly to keep F symmetric with G.
        self.F = Generator('F', self.is_training, ngf=ngf, norm=norm,
                           image_size=image_size)
        self.D_X = Discriminator('D_X',
                                 self.is_training, norm=norm,
                                 use_sigmoid=use_sigmoid)

        # Placeholders for fake images fed when training the discriminators,
        # shape (batch, image_size, image_size, 3).
        # NOTE(review): presumably fed from a history buffer of previously
        # generated images rather than the current batch -- confirm against
        # the training loop.
        self.fake_x = tf.placeholder(tf.float32,
                                     shape=[batch_size, image_size, image_size, 3])
        self.fake_y = tf.placeholder(tf.float32,
                                     shape=[batch_size, image_size, image_size, 3])

    def model(self):
        """Build the full training graph: readers, losses and summaries.

        Returns:
          (G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x): the four
          scalar loss tensors and the two generated-image tensors of shape
          (batch_size, image_size, image_size, 3).
        """
        X_reader = Reader(self.X_train_file, name='X',
                          image_size=self.image_size, batch_size=self.batch_size)
        Y_reader = Reader(self.Y_train_file, name='Y',
                          image_size=self.image_size, batch_size=self.batch_size)

        x = X_reader.feed()  # batch of real images from domain X (e.g. apples)
        y = Y_reader.feed()  # batch of real images from domain Y (e.g. oranges)

        cycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)  # Eq. (2) in the paper

        # X -> Y direction
        fake_y = self.G(x)  # translated images, shape (batch, image_size, image_size, 3)
        G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan)

        # Both generator objectives share the full cycle-consistency term.
        G_loss = G_gan_loss + cycle_loss

        # NOTE(review): D_Y is trained against self.fake_y (a placeholder),
        # presumably fed with previously generated images from a history
        # buffer instead of the current fake_y -- confirm against the caller.
        D_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y,
                                           use_lsgan=self.use_lsgan)  # Eq. (1) in the paper

        # Y -> X direction
        fake_x = self.F(y)  # translated images, shape (batch, image_size, image_size, 3)

        F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan)

        F_loss = F_gan_loss + cycle_loss

        D_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan)

        # summary: discriminator score histograms on real and fake images
        tf.summary.histogram('D_Y/true', self.D_Y(y))
        tf.summary.histogram('D_Y/fake', self.D_Y(self.G(x)))
        tf.summary.histogram('D_X/true', self.D_X(x))
        tf.summary.histogram('D_X/fake', self.D_X(self.F(y)))

        tf.summary.scalar('loss/G', G_gan_loss)
        tf.summary.scalar('loss/D_Y', D_Y_loss)
        tf.summary.scalar('loss/F', F_gan_loss)
        tf.summary.scalar('loss/D_X', D_X_loss)
        tf.summary.scalar('loss/cycle', cycle_loss)

        # image summaries: translated samples and their cycle reconstructions
        tf.summary.image('X/generated', utils.batch_convert2int(self.G(x)))
        tf.summary.image('X/reconstruction', utils.batch_convert2int(self.F(self.G(x))))
        tf.summary.image('Y/generated', utils.batch_convert2int(self.F(y)))
        tf.summary.image('Y/reconstruction', utils.batch_convert2int(self.G(self.F(y))))

        return G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x

x域 -> y域: G(x) -> y′
y域 -> x域: F(y) -> x′
F(G(x)) = F(y’) -> x’’ 相当于用x转到y域的函数G将图像x转为y,再用y域转x域的函数F将图像y转为x,做逆操作,求图像x两次转变的损失误差,
tf.reduce_mean(tf.abs(F(G(x)) - x)), 同理: tf.reduce_mean(tf.abs(G(F(y)) - y))
所以:
cycle consistency loss (L1 norm)循环一致性损失函数:

    def cycle_consistency_loss(self, G, F, x, y):
        """Cycle-consistency loss (L1 norm), Eq. (2) of the CycleGAN paper.

        Penalizes the round-trip reconstruction error through both
        generators: F(G(x)) should recover x and G(F(y)) should recover y.
        """
        x_roundtrip_err = tf.reduce_mean(tf.abs(F(G(x)) - x))
        y_roundtrip_err = tf.reduce_mean(tf.abs(G(F(y)) - y))
        # lambda1/lambda2 weight the forward and backward terms (10 each by default).
        return self.lambda1 * x_roundtrip_err + self.lambda2 * y_roundtrip_err

生成器损失函数

    def generator_loss(self, D, fake_y, use_lsgan=True):
        """Adversarial generator loss: fool D into scoring fakes as real.

        Second half of Eq. (1) in the paper (the generator's objective).

        Args:
          D: discriminator of the target domain.
          fake_y: generated images for that domain,
            shape (batch, image_size, image_size, 3).
          use_lsgan: if True use the least-squares loss, otherwise the
            heuristic non-saturating log loss.

        Returns:
          scalar loss tensor
        """
        score = D(fake_y)
        if use_lsgan:
            # Least-squares loss: push discriminator scores toward REAL_LABEL
            # (a smoothed "real" target rather than a hard 1.0).
            return tf.reduce_mean(tf.squared_difference(score, REAL_LABEL))
        # Heuristic non-saturating loss: maximize log D(fake) instead of
        # minimizing log(1 - D(fake)).
        return -tf.reduce_mean(ops.safe_log(score)) / 2

判别器损失函数

    def discriminator_loss(self, D, y, fake_y, use_lsgan=True):
        """Adversarial discriminator loss, first half of Eq. (1) in the paper.

        Trains D to score real images of its domain as real and generated
        images as fake.

        Args:
          D: discriminator for the domain of y.
          y: a batch of real images from that domain.
          fake_y: generated images for that domain,
            shape (batch, image_size, image_size, 3).
          use_lsgan: if True use the least-squares loss, otherwise
            cross-entropy.

        Returns:
          loss: scalar
        """
        if use_lsgan:
            # Least-squares: real scores regress toward REAL_LABEL (smoothed
            # "real" target), fake scores regress toward 0.
            real_term = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))
            fake_term = tf.reduce_mean(tf.square(D(fake_y)))
        else:
            # Cross-entropy: -log D(real) and -log(1 - D(fake)).
            real_term = -tf.reduce_mean(ops.safe_log(D(y)))
            fake_term = -tf.reduce_mean(ops.safe_log(1 - D(fake_y)))
        # Average the real and fake terms.
        return (real_term + fake_term) / 2

你可能感兴趣的:(gan)