Generating Handwritten Digit Images with MNIST and a GAN: Code

Contents

  • 1. Using PyTorch's built-in BCELoss
  • 2. Using a hand-written BCELoss
  • 3. Takeaways from generating handwritten digits with MNIST and a GAN
      • (1) Custom losses
      • (2) Updating network parameters
      • (3) Learning rate

1. Using PyTorch's built-in BCELoss

import os

import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt


def generate_noise(num):
    # Sample `num` noise images with the same shape as MNIST digits (1x28x28).
    return torch.randn(size=(num, 1, 28, 28))


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(784, 1568),
                                    nn.ReLU(),
                                    nn.Linear(in_features=1568, out_features=1200),
                                    nn.ReLU(),
                                    nn.Linear(in_features=1200, out_features=784),
                                    nn.Sigmoid(),
                                    nn.Unflatten(dim=1, unflattened_size=torch.Size([1, 28, 28])))

    def forward(self, X):
        Y = self.layers(X)
        return Y


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(in_features=784, out_features=392),
                                    nn.ReLU(),
                                    nn.Linear(in_features=392, out_features=196),
                                    nn.ReLU(),
                                    nn.Linear(in_features=196, out_features=1),
                                    nn.Sigmoid())

    def forward(self, X):
        Y = self.layers(X)
        return Y


def train_discriminator_one_times(discriminator, generator, train_loader, batch_size,
                                  criterion_discriminator, optimizer_discriminator):
    for datas, _ in train_loader:
        # Draw fresh noise and produce a batch of fake images. The fakes are
        # deliberately NOT detached here; see the discussion in section 3(2).
        noises = generate_noise(batch_size)
        fakes = generator(noises)

        # Real images first (label 1), fakes second (label 0).
        images = torch.cat([datas, fakes], dim=0)
        discriminator.train()
        dis_result = discriminator(images).reshape(batch_size * 2)
        labels = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)])
        loss = criterion_discriminator(dis_result, labels)
        optimizer_discriminator.zero_grad()
        loss.sum().backward()
        optimizer_discriminator.step()

    return discriminator


def train_generator_one_times(discriminator, generator, batch_size, criterion_generator,
                              optimizer_generator):
    noises = generate_noise(batch_size)
    generator.train()
    fakes = generator(noises)
    # Wrapping the discriminator call in no_grad, as below, would cut the
    # autograd graph and the generator would receive no gradients at all:
    # with torch.no_grad():
    #     discriminator.eval()
    #     dis_result = discriminator(fakes).reshape(batch_size)
    dis_result = discriminator(fakes).reshape(batch_size)
    # The generator wants its fakes classified as real, i.e. label 1.
    labels = torch.ones(batch_size)
    loss = criterion_generator(dis_result, labels)
    optimizer_generator.zero_grad()
    loss.sum().backward()
    optimizer_generator.step()

    return generator


def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)


if __name__ == '__main__':
    batch_size = 400
    epochs = 100
    K = 2
    G = 25
    lr = 0.03
    noise = generate_noise(1)  # fixed noise, reused to visualize progress each epoch

    # Note: train=False loads the 10,000-image MNIST test split.
    MNIST_train_dataset = torchvision.datasets.MNIST(root='MNIST', train=False,
                                                     transform=torchvision.transforms.ToTensor(),
                                                     download=True)
    train_loader = DataLoader(dataset=MNIST_train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    torch.manual_seed(5)
    discriminator = Discriminator()
    discriminator.apply(init_weights)
    generator = Generator()
    generator.apply(init_weights)
    criterion_discriminator = torch.nn.BCELoss()
    criterion_generator = torch.nn.BCELoss()
    optimizer_discriminator = torch.optim.SGD(discriminator.parameters(), lr=lr)
    optimizer_generator = torch.optim.SGD(generator.parameters(), lr=lr)

    # Resume from checkpoints when they exist; on a first run these files have
    # not been written yet, so guard against a missing-file error.
    if os.path.exists('generator_state_dict.pt'):
        generator.load_state_dict(torch.load('generator_state_dict.pt'))
    if os.path.exists('discriminator_state_dict.pt'):
        discriminator.load_state_dict(torch.load('discriminator_state_dict.pt'))

    for epoch in range(epochs):
        for k in range(K):
            discriminator = train_discriminator_one_times(discriminator=discriminator, generator=generator,
                                                          train_loader=train_loader, batch_size=batch_size,
                                                          criterion_discriminator=criterion_discriminator,
                                                          optimizer_discriminator=optimizer_discriminator)
            torch.save(discriminator, 'discriminator.pt')
            torch.save(discriminator.state_dict(), 'discriminator_state_dict.pt')
            print(f'Epoch {epoch}, discriminator pass {k} finished.')
        for g in range(G):  # batch_size * G noise samples are trained per epoch
            generator = train_generator_one_times(discriminator=discriminator, generator=generator,
                                                  batch_size=batch_size, criterion_generator=criterion_generator,
                                                  optimizer_generator=optimizer_generator)
            torch.save(generator.state_dict(), 'generator_state_dict.pt')
            torch.save(generator, 'generator.pt')
        print(f'Epoch {epoch}, generator training finished.')
        with torch.no_grad():
            generator.eval()
            fake = generator(noise)
            plt.imshow(fake[0][0].numpy(), cmap='gray')
            plt.show()
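
A note on usage: because the loop saves the state dicts after every discriminator and generator pass, and the guarded load_state_dict calls above restore them, the script can be interrupted and relaunched and will pick up roughly where it left off.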

2. Using a hand-written BCELoss

import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt


def discriminator_loss(dis_result, batch_size):
    # Hand-written BCE for the discriminator: the first batch_size entries of
    # dis_result are real images (label 1), the second batch_size are fakes (label 0).
    loss = 0
    for i in range(batch_size):
        a = torch.log(dis_result[i])
        b = torch.log(1 - dis_result[i + batch_size])
        # Clamp log(0) = -inf to -100, mirroring what nn.BCELoss does internally.
        if a.item() < -100:
            a = torch.tensor(-100.0)
        if b.item() < -100:
            b = torch.tensor(-100.0)
        loss -= (a + b)
    return loss / (2 * batch_size)


def generator_loss(dis_result, batch_size):
    # Hand-written BCE for the generator: every fake should be classified as real.
    loss = 0
    for i in range(batch_size):
        a = torch.log(dis_result[i])
        if a.item() < -100:
            a = torch.tensor(-100.0)
        loss -= a
    return loss / batch_size


def generate_noise(num):
    return torch.randn(size=(num, 1, 28, 28))


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(784, 1200),
                                    # nn.BatchNorm1d(1200),
                                    nn.GELU(),
                                    nn.Linear(in_features=1200, out_features=1600),
                                    # nn.BatchNorm1d(1600),
                                    nn.GELU(),
                                    nn.Linear(in_features=1600, out_features=784),
                                    nn.Sigmoid(),
                                    nn.Unflatten(dim=1, unflattened_size=torch.Size([1, 28, 28])))

    def forward(self, X):
        Y = self.layers(X)
        return Y


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Flatten(),
                                    nn.Linear(in_features=784, out_features=392),
                                    # nn.BatchNorm1d(392),
                                    nn.ReLU(),
                                    nn.Linear(in_features=392, out_features=196),
                                    # nn.BatchNorm1d(196),
                                    nn.ReLU(),
                                    nn.Linear(in_features=196, out_features=1),
                                    nn.Sigmoid())

    def forward(self, X):
        Y = self.layers(X)
        return Y


def train_discriminator_one_times(discriminator, generator, train_loader, batch_size,
                                  criterion_discriminator, optimizer_discriminator):
    for datas, _ in train_loader:
        # As in section 1, the fakes are intentionally not detached;
        # see the discussion in section 3(2).
        noises = generate_noise(batch_size)
        fakes = generator(noises)

        images = torch.cat([datas, fakes], dim=0)
        discriminator.train()
        dis_result = discriminator(images).reshape(batch_size * 2)
        # The custom loss takes batch_size instead of a label tensor: it assumes
        # the first half of dis_result is real and the second half is fake.
        loss = criterion_discriminator(dis_result, batch_size)
        optimizer_discriminator.zero_grad()
        loss.sum().backward()
        optimizer_discriminator.step()

    return discriminator


def train_generator_one_times(discriminator, generator, batch_size, criterion_generator,
                              optimizer_generator):
    noises = generate_noise(batch_size)
    generator.train()
    fakes = generator(noises)
    # The no_grad variant below would cut the autograd graph and leave the
    # generator without gradients:
    # with torch.no_grad():
    #     discriminator.eval()
    #     dis_result = discriminator(fakes).reshape(batch_size)
    dis_result = discriminator(fakes).reshape(batch_size)
    loss = criterion_generator(dis_result, batch_size)
    optimizer_generator.zero_grad()
    loss.sum().backward()
    optimizer_generator.step()

    return generator


if __name__ == '__main__':
    batch_size = 400
    epochs = 100
    K = 1
    G = 25
    lr = 0.1

    MNIST_train_dataset = torchvision.datasets.MNIST(root='MNIST', train=False,
                                                     transform=torchvision.transforms.ToTensor(),
                                                     download=True)
    train_loader = DataLoader(dataset=MNIST_train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    torch.manual_seed(5)
    discriminator = Discriminator()
    generator = Generator()
    criterion_discriminator = discriminator_loss
    criterion_generator = generator_loss
    optimizer_discriminator = torch.optim.SGD(discriminator.parameters(), lr=lr)
    optimizer_generator = torch.optim.SGD(generator.parameters(), lr=lr)

    # generator_state_dict = torch.load('generator_state_dict.pt')
    # generator.load_state_dict(generator_state_dict)
    # discriminator_state_dict = torch.load('discriminator_state_dict.pt')
    # discriminator.load_state_dict(discriminator_state_dict)

    for epoch in range(epochs):
        for k in range(K):
            discriminator = train_discriminator_one_times(discriminator=discriminator, generator=generator,
                                                          train_loader=train_loader, batch_size=batch_size,
                                                          criterion_discriminator=criterion_discriminator,
                                                          optimizer_discriminator=optimizer_discriminator)
            # torch.save(discriminator, 'discriminator.pt')
            # torch.save(discriminator.state_dict(), 'discriminator_state_dict.pt')
            print(f'Epoch {epoch}, discriminator pass {k} finished.')
        for g in range(G):  # batch_size * G noise samples are trained per epoch
            generator = train_generator_one_times(discriminator=discriminator, generator=generator,
                                                  batch_size=batch_size, criterion_generator=criterion_generator,
                                                  optimizer_generator=optimizer_generator)
            # torch.save(generator.state_dict(), 'generator_state_dict.pt')
            # torch.save(generator, 'generator.pt')
        print(f'Epoch {epoch}, generator training finished.')
        with torch.no_grad():
            generator.eval()
            noise = generate_noise(1)
            fake = generator(noise)
            plt.imshow(fake[0][0].numpy(), cmap='gray')
            plt.show()
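
Compared with the first script, this version uses K=1 discriminator passes per epoch, a lower learning rate of 0.1, a GELU-based generator, and the hand-written losses in place of nn.BCELoss. A quick spot check that the hand-written discriminator loss agrees with PyTorch's built-in one (a sketch; it assumes discriminator_loss from above is in scope):

import torch
from torch import nn

torch.manual_seed(0)
batch_size = 4
probs = torch.rand(2 * batch_size)   # stand-in discriminator outputs in (0, 1)
targets = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)])

custom = discriminator_loss(probs, batch_size)   # hand-written mean BCE
builtin = nn.BCELoss()(probs, targets)           # PyTorch's implementation
print(custom.item(), builtin.item())             # should agree to float precision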


3. Takeaways from generating handwritten digits with MNIST and a GAN

(1) Custom losses

A loss that looks perfectly reasonable is often wrong in practice and may not match what the network actually needs. At first I thought it made good sense to replace the discriminator's BCELoss with the sum of $1 - x_n$ over the positive samples plus the sum of $x_n$ over the negative samples, but with that loss the network simply failed to train. Once I switched to the hand-written BCELoss, training proceeded normally.
Also note that a custom loss can contain ordinary Python if statements, as the clamping in discriminator_loss and generator_loss shows.
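
For reference, the quantity both scripts minimize for the discriminator is the standard binary cross-entropy, with real samples labeled $y_n = 1$ and fakes labeled $y_n = 0$:

$$\mathcal{L} = -\frac{1}{N}\sum_{n=1}^{N}\bigl[\,y_n \log x_n + (1 - y_n)\log(1 - x_n)\,\bigr]$$

where $x_n$ is the discriminator's output for sample $n$ and $N$ is the total number of samples; the generator's loss keeps only the $\log x_n$ term, evaluated on the fakes.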

(2) Updating network parameters

In a GAN, when training the discriminator we need to "freeze" the generator: random noise goes through the generator to produce fake samples, which are fed into the discriminator as negatives alongside the real positives. The generator is used during this step but not trained, so it is "frozen". Likewise, when training the generator we use the discriminator without training it, so the discriminator is then the "frozen" one. But how should this freezing actually be implemented?

My first instinct was that freezing the generator, say, meant wrapping its use in with torch.no_grad() to disable autograd and switching it to inference mode with generator.eval(). Those settings certainly help speed, but set up carelessly they cause a cascade of problems. If speed is not the main concern, we can skip both the no-grad context and the eval mode: although the generator's parameters do accumulate gradients while the discriminator trains, the update is performed by optimizer_discriminator.step(), and that optimizer was constructed as

optimizer_discriminator = torch.optim.SGD(discriminator.parameters(), lr=lr)

which means that a call to optimizer_discriminator.step() updates only the discriminator's parameters; the generator's parameters, whatever gradients they carry, are left untouched.
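
If you do want to stop gradients from entering the frozen network at all (and save the backward pass through it), the usual trick is to detach the fake images before handing them to the discriminator. A minimal sketch of that variant of the discriminator step, reusing the names from the first script (this would replace the body of the loop over train_loader):

noises = generate_noise(batch_size)
fakes = generator(noises).detach()  # cut the graph: backward stops at the fakes

images = torch.cat([datas, fakes], dim=0)            # datas comes from train_loader
dis_result = discriminator(images).reshape(batch_size * 2)
labels = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)])
loss = criterion_discriminator(dis_result, labels)
optimizer_discriminator.zero_grad()
loss.backward()   # gradients now reach the discriminator only
optimizer_discriminator.step()

Note that detaching must not be used in the generator step: there the whole point is that gradients flow through the discriminator back into the generator.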

(3) Learning rate

Out of habit I initially set the learning rate to 0.3, and when no results came out I raised it to 1 and then forgot to change it back. That learning rate of 1 caused all sorts of problems; only after dialing it back to 0.1 did training finally succeed. In short, the learning rate is an important and easily overlooked factor in training.
