discriminator:
Discriminator network structure:
Five convolutional blocks (conv + normalization + leaky_relu, where the normalization follows the `norm` argument, 'instance' by default); the 5th conv reduces the channel count to 1 so the output can be fed directly into the loss.
(?,256,256,3) ->1 (?,128,128,64) ->2 (?,64,64,128) ->3 (?,32,32,256) ->4 (?,16,16,512) ->5 (?,16,16,1)
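A minimal sketch of that layer stack, assuming a TF 1.x environment with tf.contrib available; the helper below is illustrative, not the repository's own Discriminator class:

import tensorflow as tf  # TF 1.x API assumed throughout

def discriminator_sketch(image, use_sigmoid=False, reuse=False):
    """(?,256,256,3) -> (?,16,16,1): four stride-2 conv blocks, then a 1-channel conv."""
    with tf.variable_scope('D_sketch', reuse=reuse):
        h = image
        for filters in (64, 128, 256, 512):            # blocks 1-4: conv + norm + leaky_relu
            h = tf.layers.conv2d(h, filters, 4, strides=2, padding='same')
            h = tf.contrib.layers.instance_norm(h)     # norm='instance' is the default here
            h = tf.nn.leaky_relu(h, alpha=0.2)
        logits = tf.layers.conv2d(h, 1, 4, strides=1, padding='same')  # block 5 -> (?,16,16,1)
        return tf.sigmoid(logits) if use_sigmoid else logits  # sigmoid only when not using LSGAN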
generator:
Generator network structure:
(?,256,256,3) ->1 (?,256,256,64) ->2 (?,128,128,128) ->3 (?,64,64,256) ->4 (?,64,64,256) (the result of appending 9 residual blocks to layer 3) ->5 (?,128,128,128) (upsampling) ->6 (?,256,256,64) ->7 (?,256,256,3)
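The nine residual blocks at step 4 keep the (?,64,64,256) shape constant. A minimal sketch of one such block, under the same TF 1.x assumptions as above:

def residual_block_sketch(x, filters=256):
    # Two 3x3 convs at constant resolution; the input is added back,
    # so the block leaves the (?,64,64,256) shape unchanged.
    h = tf.layers.conv2d(x, filters, 3, padding='same')
    h = tf.contrib.layers.instance_norm(h)
    h = tf.nn.relu(h)
    h = tf.layers.conv2d(h, filters, 3, padding='same')
    h = tf.contrib.layers.instance_norm(h)
    return x + h  # identity shortcut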
Overall network structure:
class CycleGAN:
    def __init__(self,
                 X_train_file='',
                 Y_train_file='',
                 batch_size=1,
                 image_size=256,
                 use_lsgan=True,
                 norm='instance',
                 lambda1=10,
                 lambda2=10,
                 learning_rate=2e-4,
                 beta1=0.5,
                 ngf=64
                 ):
        """
        Args:
          X_train_file: string, X tfrecords file for training
          Y_train_file: string, Y tfrecords file for training
          batch_size: integer, batch size
          image_size: integer, image size
          lambda1: integer, weight for forward cycle loss (X->Y->X)
          lambda2: integer, weight for backward cycle loss (Y->X->Y)
          use_lsgan: boolean
          norm: 'instance' or 'batch'
          learning_rate: float, initial learning rate for Adam
          beta1: float, momentum term of Adam
          ngf: number of gen filters in first conv layer
        """
        self.lambda1 = lambda1              # 10
        self.lambda2 = lambda2              # 10
        self.use_lsgan = use_lsgan          # True
        use_sigmoid = not use_lsgan         # False
        self.batch_size = batch_size        # 1
        self.image_size = image_size        # 256
        self.learning_rate = learning_rate  # 2e-4
        self.beta1 = beta1                  # 0.5
        self.X_train_file = X_train_file    # 'data/tfrecords/apple.tfrecords'
        self.Y_train_file = Y_train_file    # 'data/tfrecords/orange.tfrecords'
        self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training')  # True
        self.G = Generator('G', self.is_training, ngf=ngf, norm=norm,
                           image_size=image_size)  # True 64 'instance' 256; generator X domain -> Y domain, output (1,256,256,3)
        self.D_Y = Discriminator('D_Y',
                                 self.is_training, norm=norm,
                                 use_sigmoid=use_sigmoid)  # True 'instance' False; discriminator for the Y domain, output (1,16,16,1)
        self.F = Generator('F', self.is_training, norm=norm,
                           image_size=image_size)  # True 'instance' 256; generator Y domain -> X domain, output (1,256,256,3)
        self.D_X = Discriminator('D_X',
                                 self.is_training, norm=norm,
                                 use_sigmoid=use_sigmoid)  # True 'instance' False; discriminator for the X domain, output (1,16,16,1)
        self.fake_x = tf.placeholder(tf.float32,
                                     shape=[batch_size, image_size, image_size, 3])  # (1,256,256,3)
        self.fake_y = tf.placeholder(tf.float32,
                                     shape=[batch_size, image_size, image_size, 3])  # (1,256,256,3)
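A hedged usage sketch, reusing the tfrecords paths from the inline comments above; all arguments simply restate the constructor defaults:

cycle_gan = CycleGAN(
    X_train_file='data/tfrecords/apple.tfrecords',
    Y_train_file='data/tfrecords/orange.tfrecords',
    batch_size=1, image_size=256,
    use_lsgan=True, norm='instance',
    lambda1=10, lambda2=10,
    learning_rate=2e-4, beta1=0.5, ngf=64)
G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()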
    def model(self):
        X_reader = Reader(self.X_train_file, name='X',
                          image_size=self.image_size, batch_size=self.batch_size)
        Y_reader = Reader(self.Y_train_file, name='Y',
                          image_size=self.image_size, batch_size=self.batch_size)
        x = X_reader.feed()  # apple batch; decoded images come off the reader here
        y = Y_reader.feed()  # orange batch; x and y are drawn independently from the two datasets
        cycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)  # Eq. (2) in the paper

        # X -> Y (apple -> orange)
        fake_y = self.G(x)  # Y generated from X, (1,256,256,3)
        G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan)  # adversarial loss of G's output under discriminator D_Y
        G_loss = G_gan_loss + cycle_loss
        D_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y,
                                           use_lsgan=self.use_lsgan)  # D_Y on real y vs. generated y; Eq. (1) in the paper

        # Y -> X (orange -> apple)
        fake_x = self.F(y)  # X generated from Y, (1,256,256,3)
        F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan)
        F_loss = F_gan_loss + cycle_loss
        D_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan)
        # summary
        tf.summary.histogram('D_Y/true', self.D_Y(y))
        tf.summary.histogram('D_Y/fake', self.D_Y(self.G(x)))
        tf.summary.histogram('D_X/true', self.D_X(x))
        tf.summary.histogram('D_X/fake', self.D_X(self.F(y)))
        tf.summary.scalar('loss/G', G_gan_loss)
        tf.summary.scalar('loss/D_Y', D_Y_loss)
        tf.summary.scalar('loss/F', F_gan_loss)
        tf.summary.scalar('loss/D_X', D_X_loss)
        tf.summary.scalar('loss/cycle', cycle_loss)
        tf.summary.image('X/generated', utils.batch_convert2int(self.G(x)))
        tf.summary.image('X/reconstruction', utils.batch_convert2int(self.F(self.G(x))))
        tf.summary.image('Y/generated', utils.batch_convert2int(self.F(y)))
        tf.summary.image('Y/reconstruction', utils.batch_convert2int(self.G(self.F(y))))
        return G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x
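The training script (outside this section) is what consumes these returned losses. A hedged sketch of wiring Adam optimizers to them, assuming each sub-network exposes its trainable variables as a .variables list (that attribute is an assumption about the surrounding repo, not something shown here):

def make_optimizer(loss, variables, name):
    # Adam with the constructor defaults (learning_rate=2e-4, beta1=0.5);
    # a real training loop would typically also decay the learning rate.
    return tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5,
                                  name=name).minimize(loss, var_list=variables)

G_optimizer   = make_optimizer(G_loss,   cycle_gan.G.variables,   'Adam_G')    # .variables assumed
D_Y_optimizer = make_optimizer(D_Y_loss, cycle_gan.D_Y.variables, 'Adam_D_Y')
F_optimizer   = make_optimizer(F_loss,   cycle_gan.F.variables,   'Adam_F')
D_X_optimizer = make_optimizer(D_X_loss, cycle_gan.D_X.variables, 'Adam_D_X')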
X domain -> Y domain        Y domain -> X domain
G(x) -> y'                  F(y) -> x'
F(G(x)) = F(y') -> x''
That is: map image x into the Y domain with G, then map the result back into the X domain with F (the inverse operation), and measure the error between x and its twice-translated version:
tf.reduce_mean(tf.abs(F(G(x)) - x)), and likewise tf.reduce_mean(tf.abs(G(F(y)) - y))
Hence the cycle consistency loss (L1 norm):
    def cycle_consistency_loss(self, G, F, x, y):  # Eq. (2) in the paper
        """ cycle consistency loss (L1 norm)
        """
        forward_loss = tf.reduce_mean(tf.abs(F(G(x)) - x))
        backward_loss = tf.reduce_mean(tf.abs(G(F(y)) - y))
        loss = self.lambda1 * forward_loss + self.lambda2 * backward_loss  # 10, 10
        return loss
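For reference, this is Eq. (2) of the CycleGAN paper, which the code above implements with lambda1 = lambda2 = 10:

\mathcal{L}_{\mathrm{cyc}}(G,F) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
+ \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]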
Generator loss function
    def generator_loss(self, D, fake_y, use_lsgan=True):  # the generator half of the improved Eq. (1) in the paper
        """
        fool discriminator into believing that G(x) is real
        This loss pushes the x -> y output y' toward the image distribution of domain Y, where
        D is the discriminator of the Y domain, and
        fake_y is the fake image generated by x -> y, of shape (1,256,256,3);
        the Y -> X direction is symmetric.
        """
        # self.D_Y, fake_y, True
        if use_lsgan:
            # use mean squared error
            # D(fake_y) is the discriminator output, of shape (1,16,16,1); REAL_LABEL = 0.9,
            # so pushing the output toward 0.9 counts as fooling D (label smoothing).
            loss = tf.reduce_mean(
                tf.squared_difference(D(fake_y), REAL_LABEL))
        else:
            # heuristic, non-saturating loss
            loss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2
        return loss
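In LSGAN form, the generator term being minimized is the following (with the real label 1 softened to REAL_LABEL = 0.9 in this code):

\mathcal{L}_{\mathrm{GAN}}(G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[(D_Y(G(x)) - 1)^2\big]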
Discriminator loss function
    def discriminator_loss(self, D, y, fake_y, use_lsgan=True):  # the discriminator half of the improved Eq. (1) in the paper
        """
        This loss trains D to tell real domain-Y images apart from the x -> y outputs, where
        D is the discriminator of the Y domain,
        y is an image from the Y domain, and
        fake_y is the fake image generated by x -> y, of shape (1,256,256,3).
        Returns:
          loss: scalar
        """
        if use_lsgan:  # self.D_Y, y, self.fake_y, use_lsgan=self.use_lsgan
            # use mean squared error
            # D(y) has shape (1,16,16,1); pushing it toward REAL_LABEL = 0.9 makes a real y
            # score as real under the Y-domain discriminator (the "real" term).
            error_real = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))
            # Pushing D(fake_y) toward 0 makes the discriminator flag images translated
            # from the X domain as fake (the "- 0" is omitted; 1 means real, 0 means fake).
            error_fake = tf.reduce_mean(tf.square(D(fake_y)))
        else:
            # use cross entropy
            error_real = -tf.reduce_mean(ops.safe_log(D(y)))
            error_fake = -tf.reduce_mean(ops.safe_log(1 - D(fake_y)))
        loss = (error_real + error_fake) / 2  # discriminator loss
        return loss
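And the matching LSGAN discriminator term (again with 1 softened to 0.9 in this implementation):

\mathcal{L}_{\mathrm{GAN}}(D_Y) =
  \tfrac{1}{2}\,\mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[(D_Y(y) - 1)^2\big]
+ \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[D_Y(G(x))^2\big]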