原始GAN的缺陷:
生成的图像是随机的、不可预测的,无法控制网络输出特定的图片;生成目标不明确,可控性不强。
针对原始GAN不能生成具有特定属性图片的问题,CGAN的核心思想是将属性信息 y 融入生成器 G 和判别器 D 中,属性 y 可以是任何标签类信息。
CGAN的缺陷:
生成的图像边缘模糊,分辨率不够高。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): `%matplotlib inline` is IPython magic and a SyntaxError in a
# plain .py file; keep it as a comment (re-enable it inside a notebook).
# %matplotlib inline

# ---- Data pipeline ----
# Only the MNIST training split is needed for GAN training.
(images, labels), (_, _) = keras.datasets.mnist.load_data()
# Cast to float32 (uint8 / float would silently promote to float64, doubling
# memory) and rescale pixels from [0, 255] to [-1, 1] to match the
# generator's tanh output range.
images = images.astype('float32') / 127.5 - 1
images = np.expand_dims(images, -1)  # (N, 28, 28) -> (N, 28, 28, 1)

BATCH_SIZE = 128
noise_dim = 50  # length of the generator's input noise vector
BUFFER_SIZE = images.shape[0]  # shuffle buffer covering the whole dataset

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#--------------------------------------------------------------------
# 定义生成器模型
def generate_model():
    """Build the conditional generator.

    Inputs: a noise vector of length `noise_dim` and a scalar class label
    (0-9). Output: a (28, 28, 1) image with tanh-scaled pixels in [-1, 1].
    """
    seed = layers.Input(shape=(noise_dim,))
    # fixed: original `layers.Input(shape(()))` was missing `=` (NameError).
    label = layers.Input(shape=())
    # Embed the label into the same width as the noise so the two can be
    # concatenated into a single conditioning vector.
    x = layers.Embedding(10, 50, input_length=1)(label)
    # Flatten guards against a (batch, 1, 50) embedding output; fixed the
    # original `concatenate([seed, x])(x)` which called the result tensor.
    x = layers.Flatten()(x)
    x = layers.concatenate([seed, x])
    x = layers.Dense(3 * 3 * 128, use_bias=False)(x)
    x = layers.Reshape((3, 3, 128))(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # No padding argument -> 'valid': upsamples (3,3,128) to (7,7,64).
    # fixed: `use_biase` typo.
    x = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # (7,7,64) -> (14,14,32). fixed: `layes` typo and the layer was never
    # applied to `x` (missing trailing `(x)`).
    x = layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # (14,14,32) -> (28,28,1)
    x = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), use_bias=False, padding='same')(x)
    x = layers.Activation('tanh')(x)
    return keras.models.Model(inputs=(seed, label), outputs=x)
#------------------------------------------------------------------------
# 定义判别器模型
def discriminate_model():
    """Build the conditional discriminator.

    Inputs: a (28, 28, 1) image and its scalar class label. Output: a single
    unbounded logit (no sigmoid) scoring real vs. fake.
    """
    image = layers.Input(shape=(28, 28, 1))
    label = layers.Input(shape=())
    # Embed the label into a 28*28 plane so it can be stacked with the image
    # as an extra channel.
    h = layers.Embedding(10, 28 * 28, input_length=1)(label)
    h = layers.Reshape((28, 28, 1))(h)
    h = layers.concatenate([h, image])
    # Three identical strided conv blocks: 28 -> 14 -> 7 -> 4 spatially.
    for filters in (32, 64, 128):
        h = layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same', use_bias=False)(h)
        h = layers.BatchNormalization()(h)
        h = layers.LeakyReLU()(h)
        h = layers.Dropout(0.3)(h)
    h = layers.Flatten()(h)
    logit = layers.Dense(1)(h)
    return keras.models.Model(inputs=(image, label), outputs=logit)
#-------------------------------------------------------------------
# 实例化对象并自定义损失函数
# Instantiate both networks once at module level; the training step and the
# plotting helper close over these names.
generator=generate_model()
discriminator=discriminate_model()
# from_logits=True because the discriminator ends in Dense(1) with no sigmoid.
bce=keras.losses.BinaryCrossentropy(from_logits=True)
# 这里的损失函数和GAN/DCGAN网络的损失函数并无差别
def disc_loss(real_out, fake_out):
    """Discriminator loss: push real logits toward 1 and fake logits toward 0."""
    loss_on_real = bce(tf.ones_like(real_out), real_out)
    loss_on_fake = bce(tf.zeros_like(fake_out), fake_out)
    return loss_on_real + loss_on_fake
def gen_loss(fake_out):
    """Generator loss: reward fakes the discriminator scores as real (target 1)."""
    target = tf.ones_like(fake_out)
    return bce(target, fake_out)
#---------------------------------------------------------------------
# 定义优化器&自定义训练
# One Adam optimizer per network; a small learning rate (1e-5) is used here —
# presumably for GAN training stability (TODO: confirm against results).
gen_opt=keras.optimizers.Adam(1e-5)
disc_opt=keras.optimizers.Adam(1e-5)
@tf.function
def train_step(image, label):
    """Run one optimization step for both networks on a single batch.

    Both forward passes share the tapes so each network's gradients can be
    taken from its own loss.
    """
    # Use the dynamic batch size: the final batch may be smaller than
    # BATCH_SIZE, and the static shape can be None inside a tf.function.
    size = tf.shape(label)[0]
    noise = tf.random.normal((size, noise_dim))
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_image = generator((noise, label), training=True)
        real_out = discriminator((image, label), training=True)
        fake_out = discriminator((gen_image, label), training=True)
        discriminate_loss = disc_loss(real_out, fake_out)
        generate_loss = gen_loss(fake_out)
    gen_grad = gen_tape.gradient(generate_loss, generator.trainable_variables)
    disc_grad = disc_tape.gradient(discriminate_loss, discriminator.trainable_variables)
    # fixed: the method is `apply_gradients` (plural); the original
    # `apply_gradient` / `apple_gradient` / `trainable_varialbes` all raise
    # AttributeError on the first training step.
    gen_opt.apply_gradients(zip(gen_grad, generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grad, discriminator.trainable_variables))
#---------------------------------------------------------------------
# 自定义绘图函数
def plot_gen_image(model, noise, label, epoch_num):
    """Generate images from the (noise, label) pairs and show them in one row.

    `noise` has shape (n, noise_dim) and `label` holds n class labels; the
    row width now follows n instead of being hard-coded to 10.
    """
    print('Epoch:', epoch_num)
    gen_image = model((noise, label), training=False)
    # Drop the trailing channel axis: (n, 28, 28, 1) -> (n, 28, 28).
    gen_image = tf.squeeze(gen_image)
    n = gen_image.shape[0]
    plt.figure(figsize=(10, 1))
    for i in range(n):
        plt.subplot(1, n, i + 1)
        # Map tanh output from [-1, 1] back to [0, 1] for display.
        plt.imshow((gen_image[i, :, :] + 1) / 2)
        # fixed: `plt.axes('off')` is the Axes constructor, not the toggle —
        # `plt.axis('off')` hides the ticks/frame as intended.
        plt.axis('off')
    plt.show()
#-----------------------------------------------------------------------
# 启动函数
# Fixed noise/label pairs reused for every preview plot so that image quality
# is comparable across epochs. Labels are 10 random digits in [0, 10).
noise_seed=tf.random.normal([10,noise_dim])
label_seed=np.random.randint(0,10,size=(10))
def train(dataset, epochs):
    """Train the CGAN, previewing the fixed seed images every 10th epoch."""
    for ep in range(epochs):
        # One full pass over the shuffled dataset.
        for img_batch, lbl_batch in dataset:
            train_step(img_batch, lbl_batch)
        if ep % 10 == 0:
            plot_gen_image(generator, noise_seed, label_seed, ep)
    # One final preview after training (ep is epochs - 1 here).
    plot_gen_image(generator, noise_seed, label_seed, ep)
# Entry point: train for 200 epochs over the full MNIST training split.
EPOCHS=200
train(dataset,EPOCHS)