生成对抗网络(Generative Adversarial Network, GAN)是一种深度学习模型框架,由 Ian Goodfellow 等人在 2014 年提出。GAN 由 生成器(Generator) 和 判别器(Discriminator) 两个对抗网络组成,通过彼此博弈的方式训练,从而生成与真实数据分布极为相似的高质量数据。GAN 在图像生成、文本生成、数据增强等领域中有广泛应用。
GAN 的核心是两个神经网络之间的对抗:
生成器(Generator):
判别器(Discriminator):
两者通过对抗学习(min-max 游戏)达到一个动态平衡,使得生成器生成的数据逐渐逼近真实数据分布。
GAN 的目标是通过以下损失函数进行优化:
其中:
生成器 G 的目标是最小化 D(G(z)) 的概率(使判别器认为生成数据是真实的),而判别器 D 的目标是最大化其正确判断的概率。
初始化模型:随机初始化生成器和判别器的参数。
训练判别器 D:
训练生成器 G:
重复以上过程,直到生成数据的质量达到目标。
以下是一个使用 TensorFlow/Keras 实现简单 GAN 的代码示例:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Flatten, Reshape
# 创建生成器
def build_generator(latent_dim):
model = Sequential([
Dense(128, input_dim=latent_dim),
LeakyReLU(alpha=0.2),
Dense(256),
LeakyReLU(alpha=0.2),
Dense(512),
LeakyReLU(alpha=0.2),
Dense(784, activation='tanh'),
Reshape((28, 28, 1))
])
return model
# 创建判别器
def build_discriminator():
model = Sequential([
Flatten(input_shape=(28, 28, 1)),
Dense(512),
LeakyReLU(alpha=0.2),
Dense(256),
LeakyReLU(alpha=0.2),
Dense(1, activation='sigmoid')
])
return model
# 定义损失函数和优化器
latent_dim = 100
generator = build_generator(latent_dim)
discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 构建 GAN 模型
discriminator.trainable = False
gan = Sequential([generator, discriminator])
gan.compile(optimizer='adam', loss='binary_crossentropy')
# 数据准备(MNIST 数据集)
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
X_train = X_train / 127.5 - 1.0 # 归一化到 [-1, 1]
X_train = np.expand_dims(X_train, axis=-1)
# 训练 GAN
batch_size = 64
epochs = 10000
for epoch in range(epochs):
# 随机选取真实样本
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_imgs = X_train[idx]
# 生成伪造样本
noise = np.random.normal(0, 1, (batch_size, latent_dim))
fake_imgs = generator.predict(noise)
# 训练判别器
real_y = np.ones((batch_size, 1))
fake_y = np.zeros((batch_size, 1))
d_loss_real = discriminator.train_on_batch(real_imgs, real_y)
d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_y)
# 训练生成器
g_loss = gan.train_on_batch(noise, real_y)
# 输出训练进度
if epoch % 1000 == 0:
print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")
运行结果
2/2 [==============================] - 0s 2ms/step
Epoch 0: D Loss Real: 0.829259991645813, D Loss Fake: 0.6967335343360901, G Loss: 0.8764752149581909
2/2 [==============================] - 0s 3ms/step
2/2 [==============================] - 0s 2ms/step
2/2 [==============================] - 0s 3ms/step
2/2 [==============================] - 0s 2ms/step
2/2 [==============================] - 0s 2ms/step
......
GAN 的对抗思想极具创新性,为生成任务提供了一种全新的解决方案,是深度学习领域的里程碑技术之一。