Pytorch实现第一个生成对抗网络(GAN)

1、GAN 原理

以训练一个形如 “1010” 格式的向量生成器为例:

需要构造两个神经网络为:生成器(Generator)和判别器(Discriminator)其中,

  1. 生成器接受随机噪声,并据此生成一个size=4的向量。
  2. 判别器判断接受的向量是真实样本还是生成器的生成样本,给出输入是真实样本的概率

在训练过程中,生成器的目标是尽量生成真实的数据去欺骗判别器。而判别器的目标就是尽量把生成数据和真实样本区分开。训练过程实际上可以理解为生成器和判别器的博弈。

训练的最终目标:生成器能够生成足以“以假乱真”的数据,判别器不能区分输入的数据是真实数据还是生成数据,原理图如下:
Pytorch实现第一个生成对抗网络(GAN)_第1张图片

2、生成器 & 判别器

判别器和生成器选择了最简单的MLP网络,隐藏层都设置为包含3个神经元(非必须)

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        # 调用父类的构造函数,初始化父类
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
        	nn.Linear(4, 3),
        	nn.Sigmoid(),
        	nn.Linear(3, 1), 
        	nn.Sigmoid()
        )
        # 创建损失函数
        self.loss_function = nn.MSELoss()
        # 创建优化器,随机梯度下降
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
        
    def forward(self, inputs):
        return self.model(inputs)
        
    def train(self, inputs, targets):
        # 计算网络的输出值
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)
        # 反向传播
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
# 生成器
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
        	nn.Linear(1, 3),
        	nn.Sigmoid(),
        	nn.Linear(3, 4),
        	nn.Sigmoid()
        )
        # 创建优化器,随机梯度下降
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
        
    def forward(self, inputs):
        return self.model(inputs)
        
    def train(self, discriminator, inputs, targets):
        # 生成器输出
        gen_data = self.forward(inputs)
        # 判别器预测
        pred_data = discriminator(gen_data)
        # 计算损失
        self.optimiser.zero_grad()
        loss = discriminator.loss_function(pred_data, targets)
        # 从判别器误差开始,反向传播误差梯度到生成器
        loss.backward()
        # 用生成器的优化器更新自身参数
        self.optimiser.step()

3、GAN的训练方法

训练一个GAN需要实现判别器和生成器的同步优化。如果判别器的分类能力很强,生成器的分类能力很弱,不能很好地训练生成器,反之亦然。

一般来说,GAN每个step的训练包含三个步骤:

  1. 使用真实样本训练判别器(提高判别器识别真实样本的能力)
  2. 使用生成样本训练判别器(提高判别器识别生成样本的能力)
  3. 训练生成器(提高生成器生成样本的能力)

代码示例如下:

# 生成1010格式的训练正样本(引入高斯噪声)
def generate_real():
    real_data = torch.FloatTensor([
        random.uniform(0.8, 1.0),
        random.uniform(0.0, 0.2),
        random.uniform(0.8, 1.0),
        random.uniform(0.0, 0.2),
    ])
    return real_data

# 创建判别器和生成器
discriminator = Discriminator()
generator = Generator()

for i in range(10000):
    # 用真实样本训练判别器, target=1
    discriminator.train(generate_real(), torch.FloatTensor([1.0]))
    # 用生成样本训练判别器, target=0
    discriminator.train(
    	generator(torch.FloatTensor([0.5])).detach(), 
    	torch.FloatTensor([0.0])
    )
    # 训练生成器, target=1
    generator.train(
    	discriminator, 
    	torch.FloatTensor([0.5]), 
    	torch.FloatTensor([1.0])
    )

代码解析:

  1. 训练前两步的 target 分别取 1、0 很好理解,就是分别用正样本、负样本训练判别网络。
  2. 对于第三步中 target 取为 1 可以有如下理解:如果生成网络成功欺骗过判别器,则对其奖励,反之惩罚;
    所谓的“成功欺骗”就是判别网络对于生成样本的预测的判别结果为 1,所以 target 取 1。
  3. 使用判别器的loss function计算loss,梯度传播到生成器,利用生成器的优化器更新参数,判别器的参数保持不变。
  4. 第二步中的 detach 是将梯度的传播在计算图上阻断,避免计算生成网络中的梯度,节省计算成本。具体用法可以参考pytorch’doc——Tensor.detach()

4、效果分析

生成结果如下:
Pytorch实现第一个生成对抗网络(GAN)_第2张图片

判别器loss变换如下:
Pytorch实现第一个生成对抗网络(GAN)_第3张图片

趋势分析:

  1. 从0.25 下降,表示判别器的识别能力增强,之后再次上升到0.25,表示生成的样本更能“以假乱真”。
  2. 收敛到0.25的原因:采用MSE损失函数,在判断不出来的时候选择介于0~1中间的0.5作为输出,此时的loss为0.25

完整代码,以及更多示例参见 Github: https://github.com/xuzf-git/Algorithm-Toy

你可能感兴趣的:(算法学习记录,Python,神经网络,pytorch,深度学习)