Generating Digits with PyTorch

In this blog post we'll implement a generative image model that converts random noise into images of digits! The full code is available here; just clone it to your machine and it's ready to play with. As a former Torch7 user, I attempt to reproduce the results from the Torch7 post.

For this, we employ a Generative Adversarial Network (GAN). A GAN consists of two components: a generator, which converts random noise into images, and a discriminator, which tries to distinguish between generated and real images. Here, ‘real’ means that the image came from our training set of images, in contrast to the generated fakes.

To train the model we let the discriminator and generator play a game against each other. We first show the discriminator a mixed batch of real images from our training set and fake images produced by the generator. We then simultaneously optimize the discriminator to answer NO to fake images and YES to real images, and optimize the generator to fool the discriminator into believing that the fake images are real. This corresponds to minimizing the classification error with respect to the discriminator and maximizing it with respect to the generator. With careful optimization both generator and discriminator will improve, and the generator will eventually start generating convincing images.
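The game described above can be sketched as a single training iteration. This is a minimal illustration, not the code from the repository: the tiny stand-in networks, the sizes, the learning rate, and the random "real" batch are all placeholders. It also makes two common simplifications that the original code may not share: the two updates are alternated rather than performed simultaneously, and the generator is trained to make the discriminator answer YES on fakes (the so-called non-saturating loss).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and stand-in networks, chosen only to keep the sketch small.
latent_size, image_size, batch_size = 16, 784, 8
G = nn.Sequential(nn.Linear(latent_size, image_size), nn.Tanh())
D = nn.Sequential(nn.Linear(image_size, 1), nn.Sigmoid())

criterion = nn.BCELoss()  # classification error of the discriminator
d_optimizer = torch.optim.SGD(D.parameters(), lr=0.01)  # assumed settings
g_optimizer = torch.optim.SGD(G.parameters(), lr=0.01)  # assumed settings

# Placeholder for a batch of real training images scaled to [-1, 1].
real_images = torch.rand(batch_size, image_size) * 2 - 1
real_labels = torch.ones(batch_size, 1)   # YES
fake_labels = torch.zeros(batch_size, 1)  # NO

# Discriminator step: answer YES to real images and NO to fakes.
z = torch.randn(batch_size, latent_size)
fake_images = G(z)
d_loss = (criterion(D(real_images), real_labels)
          + criterion(D(fake_images.detach()), fake_labels))
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

# Generator step: push the discriminator towards YES on the fakes.
g_loss = criterion(D(fake_images), real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
```

The `detach()` call keeps the discriminator update from propagating gradients into the generator; the generator only receives gradients in its own step.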

Implementing a GAN

We implement the generator and discriminator as multilayer perceptrons (MLPs) and train them with stochastic gradient descent.

The discriminator is an MLP built from consecutive blocks of a Linear layer followed by a LeakyReLU activation.

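import torch.nn as nn

# image_size is 784 (a flattened 28x28 image); hidden_size is the hidden layer width.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),  # negative slope of 0.2
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())  # probability that the input image is real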

This is a pretty standard architecture. The 28x28 grey images of digits are converted into a 784x1 vector by stacking their columns. The discriminator takes this 784x1 vector as input and predicts YES or NO with a single sigmoid output.
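As a small illustration of the flattening step, the sketch below reshapes a placeholder batch with the usual MNIST layout (batch, channel, height, width) into vectors of length 784. The random batch is an assumption standing in for real data. Note that `view` flattens in row-major order rather than literally stacking columns; for an MLP the ordering does not matter as long as it is used consistently.

```python
import torch

# Placeholder batch shaped like MNIST: (batch, channels, height, width).
images = torch.rand(4, 1, 28, 28)

# Flatten every image into a 784-dimensional vector.
flat = images.view(images.size(0), -1)
print(flat.shape)  # torch.Size([4, 784])
```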

The generator is an MLP that maps a latent noise vector to an image through repeated Linear layers with ReLU activations, followed by a final Tanh output:

# latent_size is the dimension of the input noise vector.
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())  # pixel values in [-1, 1]

To generate an image we feed the generator noise drawn from the standard normal distribution N(0, 1). After successful training, the output should be meaningful images!

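# Sample a batch of latent vectors from N(0, 1) and move it to the GPU.
z = torch.randn(batch_size, latent_size).cuda()
# Since PyTorch 0.4, tensors track gradients directly, so no Variable wrapper is needed.
fake_images = G(z)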
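To view the generated vectors as pictures, they need to be reshaped back to 28x28 and rescaled from the Tanh range. The following is a rough sketch under stated assumptions: the values of `latent_size` and `hidden_size` are illustrative guesses, the generator here is untrained, and the device selection is added so the snippet also runs without a GPU.

```python
import torch
import torch.nn as nn

# Assumed sizes; only image_size (28 * 28) follows from the text.
latent_size, hidden_size, image_size = 64, 256, 784

G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

# Fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G = G.to(device)

# No gradients are needed when we only sample images.
with torch.no_grad():
    z = torch.randn(16, latent_size, device=device)
    fake_images = G(z)

# Tanh outputs lie in [-1, 1]; map them to [0, 1] and restore the 28x28 layout.
images = ((fake_images + 1) / 2).view(-1, 1, 28, 28)
```

The resulting tensor has the (batch, channel, height, width) layout that image utilities such as `torchvision.utils.save_image` expect.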

Generating digits

We train our GAN on images of digits from the MNIST dataset. After around 5 epochs you should start to see blurry digits, and after 80 epochs the results look pleasant.

after 5 epochs

after 100 epochs

Loss and Discriminator Accuracy

accuracy.png

loss.png

We also record the accuracy of the discriminator and the losses of the discriminator and the generator after each epoch. According to the GAN paper, when the global minimum of the training criterion is achieved, the criterion equals -log 4 (equivalently, a discriminator loss of log 4 ≈ 1.386 when it is computed as the sum of the binary cross-entropy on the real and the fake batch) and the discriminator accuracy should be 0.5. Our model does not achieve this ideal result. Moreover, as we can see from the figure, our model starts to overfit at around 80 epochs. The model structure and the training strategy should be improved in the future.
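The theoretical optimum is easy to check numerically. At that point the discriminator cannot tell real from fake and outputs 0.5 for every image. The sketch below assumes that the discriminator loss is the sum of the binary cross-entropy on a real batch and on a fake batch, which is a common convention but not necessarily how the repository computes it; the constant tensors stand in for discriminator outputs.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# At the theoretical optimum the discriminator outputs 0.5 for every input.
d_out_real = torch.full((8, 1), 0.5)
d_out_fake = torch.full((8, 1), 0.5)

# Discriminator loss: real images labelled 1 (YES), fakes labelled 0 (NO).
d_loss = (criterion(d_out_real, torch.ones(8, 1))
          + criterion(d_out_fake, torch.zeros(8, 1)))

# Generator loss when it is trained to obtain YES on its fakes.
g_loss = criterion(d_out_fake, torch.ones(8, 1))

print(d_loss.item(), math.log(4))  # both approximately 1.386
print(g_loss.item(), math.log(2))  # both approximately 0.693
```

The training criterion in the GAN paper is the negative of this discriminator loss, which is where the value -log 4 comes from.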
