以训练一个形如 “1010” 格式的向量生成器为例:
需要构造两个神经网络为:生成器(Generator)和判别器(Discriminator)其中,
在训练过程中,生成器的目标是尽量生成真实的数据去欺骗判别器。而判别器的目标就是尽量把生成数据和真实样本区分开。训练过程实际上可以理解为生成器和判别器的博弈。
训练的最终目标:生成器能够生成足以“以假乱真”的数据,判别器不能区分输入的数据是真实数据还是生成数据,原理图如下:
判别器和生成器选择了最简单的MLP网络,隐藏层都设置为包含3个神经元(非必须)
# 判别器
class Discriminator(nn.Module):
def __init__(self):
# 调用父类的构造函数,初始化父类
super().__init__()
# 定义神经网络
self.model = nn.Sequential(
nn.Linear(4, 3),
nn.Sigmoid(),
nn.Linear(3, 1),
nn.Sigmoid()
)
# 创建损失函数
self.loss_function = nn.MSELoss()
# 创建优化器,随机梯度下降
self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
def forward(self, inputs):
return self.model(inputs)
def train(self, inputs, targets):
# 计算网络的输出值
outputs = self.forward(inputs)
loss = self.loss_function(outputs, targets)
# 反向传播
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
# 生成器
class Generator(nn.Module):
def __init__(self):
super().__init__()
# 定义神经网络
self.model = nn.Sequential(
nn.Linear(1, 3),
nn.Sigmoid(),
nn.Linear(3, 4),
nn.Sigmoid()
)
# 创建优化器,随机梯度下降
self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
def forward(self, inputs):
return self.model(inputs)
def train(self, discriminator, inputs, targets):
# 生成器输出
gen_data = self.forward(inputs)
# 判别器预测
pred_data = discriminator(gen_data)
# 计算损失
self.optimiser.zero_grad()
loss = discriminator.loss_function(pred_data, targets)
# 从判别器误差开始,反向传播误差梯度到生成器
loss.backward()
# 用生成器的优化器更新自身参数
self.optimiser.step()
训练一个GAN需要实现判别器和生成器的同步优化。如果判别器的分类能力很强,生成器的分类能力很弱,不能很好地训练生成器,反之亦然。
一般来说,GAN每个step的训练包含三个步骤:
代码示例如下:
# 生成1010格式的训练正样本(引入高斯噪声)
def generate_real():
real_data = torch.FloatTensor([
random.uniform(0.8, 1.0),
random.uniform(0.0, 0.2),
random.uniform(0.8, 1.0),
random.uniform(0.0, 0.2),
])
return real_data
# 创建判别器和生成器
discriminator = Discriminator()
generator = Generator()
for i in range(10000):
# 用真实样本训练判别器, target=1
discriminator.train(generate_real(), torch.FloatTensor([1.0]))
# 用生成样本训练判别器, target=0
discriminator.train(
generator(torch.FloatTensor([0.5])).detach(),
torch.FloatTensor([0.0])
)
# 训练生成器, target=1
generator.train(
discriminator,
torch.FloatTensor([0.5]),
torch.FloatTensor([1.0])
)
代码解析:
detach
是将梯度的传播在计算图上阻断,避免计算生成网络中的梯度,节省计算成本。具体用法可以参考pytorch’doc——Tensor.detach()趋势分析:
完整代码,以及更多示例参见 Github: https://github.com/xuzf-git/Algorithm-Toy