ACGAN
Table of Contents
- ACGAN
- ACGAN & CGAN
- ACGAN's loss functions
- An ACGAN implementation on the MNIST dataset
ACGAN & CGAN
- CGAN trains with label information in both the generator and the discriminator, and can only produce data for a specified label.
- ACGAN is another take on CGAN: it also trains with label information, but additionally reconstructs the label from the input data.
- The generator's input has two parts, class and noise, where class is the training-data label; its output is a batch of images of shape (batch_size, height, width, channel).
- The discriminator's input is an image (generated or real), and its output has two parts:
- 1. a real/fake judgment of the source data, of shape (batch_size, 1)
- 2. a classification of the input data, of shape (batch_size, class_num)
- The last stage of the discriminator is therefore two parallel fully connected layers, one leading to each output; i.e. the discriminator returns two tensors (a real/fake tensor and a classification tensor)
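A minimal sketch of this two-head layout (illustrative layer sizes, not the full discriminator): a shared trunk feeds two parallel Dense layers, so one forward pass yields both the real/fake logit and the class logits.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Two-head model: shared trunk, then two parallel Dense outputs.
inp = layers.Input(shape=(28, 28, 1))
h = layers.Dense(16)(layers.Flatten()(inp))
validity = layers.Dense(1)(h)    # real/fake logit: (batch_size, 1)
cls = layers.Dense(10)(h)        # class logits: (batch_size, class_num=10)
demo = tf.keras.Model(inputs=inp, outputs=[validity, cls])

v, c = demo(tf.zeros((4, 28, 28, 1)))
print(v.shape, c.shape)  # (4, 1) (4, 10)
```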
ACGAN's loss functions
- The discriminator wants both to classify the data correctly and to tell real data from fake.
- Discriminator loss: $L_D = L_S + L_C$
- Real/fake loss: $L_S = E[\log P(S=real \mid x_{real})] + E[\log P(S=fake \mid x_{fake})]$
- Classification loss: $L_C = E[\log P(C=c \mid x_{real})] + E[\log P(C=c \mid x_{fake})]$
- The generator also wants its samples classified correctly, but wants the discriminator unable to tell them apart from real data (this is what makes the setup adversarial).
- Generator loss: $L_G = L_C + L_S$
- Real/fake loss: $L_S = E[\log P(S=real \mid x_{fake})]$ (the generator pushes its fakes toward the "real" judgment)
- Classification loss: $L_C = E[\log P(C=c \mid x_{fake})]$
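In code these expectations become cross-entropy losses over logits. A small numeric sketch of the real-image part of the discriminator's objective, with made-up logits (the variable names mirror the functions defined later in this post):

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Made-up logits for a batch of two real images (3 classes for brevity).
real_out = tf.constant([[2.0], [1.5]])              # real/fake head
real_class_out = tf.constant([[3.0, 0.1, 0.2],
                              [0.1, 2.5, 0.3]])     # class head
labels = tf.constant([0, 1])

l_s = bce(tf.ones_like(real_out), real_out)   # "real" part of L_S
l_c = cce(labels, real_class_out)             # L_C on real images
total = l_s + l_c                             # the two terms simply add
print(float(l_s), float(l_c), float(total))
```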
An ACGAN implementation on the MNIST dataset
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
(images,labels),(_,_)=keras.datasets.mnist.load_data()
images = np.expand_dims(images, -1).astype('float32')
images = images / 127.5 - 1   # scale pixels to [-1, 1] to match the generator's tanh output
dataset=tf.data.Dataset.from_tensor_slices((images,labels))
BATCH_SIZE=128
noise_dim=50
BUFFER_SIZE=images.shape[0]
dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#-----------------------------------------------------------------------
# Define the generator
def generate_model():
    seed = layers.Input(shape=(noise_dim,))
    label = layers.Input(shape=())
    x = layers.Embedding(10, 50)(label)          # map the digit label to a 50-dim embedding
    x = layers.concatenate([seed, x])            # (batch_size, noise_dim + 50)
    x = layers.Dense(3*3*128, use_bias=False)(x)
    x = layers.Reshape((3, 3, 128))(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # Upsample 3 -> 7 -> 14 -> 28
    x = layers.Conv2DTranspose(64, (3, 3), strides=2, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2DTranspose(32, (3, 3), strides=2, use_bias=False, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2DTranspose(1, (3, 3), strides=2, use_bias=False, padding='same')(x)
    x = layers.Activation('tanh')(x)
    model = keras.models.Model(inputs=(seed, label), outputs=x)
    return model
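The 3x3x128 seed reaches 28x28 through the three stride-2 transposed convolutions; the output-size arithmetic can be checked by hand (for Conv2DTranspose, 'valid' padding gives s*(i-1)+k, 'same' gives s*i):

```python
def deconv_out(i, k, s, padding):
    # Output length of a 1-D transposed convolution.
    return s * (i - 1) + k if padding == 'valid' else s * i

size = 3
size = deconv_out(size, 3, 2, 'valid')  # 3 -> 7
size = deconv_out(size, 3, 2, 'same')   # 7 -> 14
size = deconv_out(size, 3, 2, 'same')   # 14 -> 28
print(size)  # 28
```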
#-----------------------------------------------------------------------
# Define the discriminator
def discriminate_model():
    image = layers.Input(shape=(28, 28, 1))
    # Downsample 28 -> 14 -> 7 -> 4
    x = layers.Conv2D(32, (3, 3), strides=2, padding='same', use_bias=False)(image)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(64, (3, 3), strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(128, (3, 3), strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Flatten()(x)
    # Real/fake output
    out = layers.Dense(1)(x)
    # Classification output
    classification_out = layers.Dense(10)(x)
    model = keras.models.Model(inputs=image, outputs=(out, classification_out))
    return model
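The mirror-image arithmetic for the discriminator: with 'same' padding, each stride-2 convolution halves the spatial size (rounding up), so the feature map shrinks 28 -> 14 -> 7 -> 4 before Flatten (4*4*128 = 2048 features feeding the two heads):

```python
import math

# 'same' padding, stride s: output = ceil(input / s)
size = 28
sizes = [size]
for _ in range(3):
    size = math.ceil(size / 2)
    sizes.append(size)
print(sizes)  # [28, 14, 7, 4]
```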
#-------------------------------------------------------------------------
# Define the loss functions
gen = generate_model()
disc = discriminate_model()
# One loss per discriminator head: binary cross-entropy for real/fake, sparse categorical cross-entropy for the class
bce = keras.losses.BinaryCrossentropy(from_logits=True)
cce = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
def disc_loss(real_out, real_class_out, fake_out, label):
    real_loss = bce(tf.ones_like(real_out), real_out)
    fake_loss = bce(tf.zeros_like(fake_out), fake_out)
    cat_loss = cce(label, real_class_out)
    total_loss = real_loss + fake_loss + cat_loss
    return total_loss
def gen_loss(fake_out, fake_class_out, label):
    fake_loss = bce(tf.ones_like(fake_out), fake_out)
    cat_loss = cce(label, fake_class_out)
    total_loss = fake_loss + cat_loss
    return total_loss
gen_opt = keras.optimizers.Adam(1e-5)
disc_opt = keras.optimizers.Adam(1e-5)
#------------------------------------------------------------------------
# Custom training step
def train_step(image, label):
    size = label.shape[0]
    noise = tf.random.normal((size, noise_dim))
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_imgs = gen((noise, label), training=True)
        fake_out, fake_class_out = disc(gen_imgs, training=True)
        real_out, real_class_out = disc(image, training=True)
        d_loss = disc_loss(real_out, real_class_out, fake_out, label)
        g_loss = gen_loss(fake_out, fake_class_out, label)
    gen_grad = gen_tape.gradient(g_loss, gen.trainable_variables)
    disc_grad = disc_tape.gradient(d_loss, disc.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grad, gen.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grad, disc.trainable_variables))
#-----------------------------------------------------------------------
# Custom plotting function
def plot_gen_image(model, noise, label, epoch_num):
    print('Epoch:', epoch_num)
    gen_image = model((noise, label), training=False)
    # Squeeze the 28*28*1 output down to a 28*28 image
    gen_image = tf.squeeze(gen_image)
    fig = plt.figure(figsize=(10, 1))
    for i in range(gen_image.shape[0]):
        plt.subplot(1, 10, i + 1)
        plt.imshow((gen_image[i, :, :] + 1) / 2, cmap='gray')
        plt.axis('off')
    plt.show()
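The post stops short of a driver loop. A hedged sketch of one, with the epoch/batch structure made explicit; `train`, `step_fn`, and `on_epoch_end` are names invented here, and a counting stub stands in for train_step so the sketch runs on its own (in a real run you would pass train_step and call plot_gen_image with a fixed noise/label probe from on_epoch_end):

```python
import tensorflow as tf

def train(dataset, epochs, step_fn, on_epoch_end=None):
    # One full pass over the batched dataset per epoch.
    for epoch in range(1, epochs + 1):
        for image, label in dataset:
            step_fn(image, label)
        if on_epoch_end is not None:
            on_epoch_end(epoch)

# Toy demonstration: synthetic data and a counting stub instead of train_step.
calls = []
toy = tf.data.Dataset.from_tensor_slices(
    (tf.zeros((8, 28, 28, 1)), tf.zeros((8,), tf.int32))).batch(4)
train(toy, epochs=2, step_fn=lambda img, lab: calls.append(1))
print(len(calls))  # 2 epochs * 2 batches = 4
```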