PyTorch 中基于 MNIST 数据集的 DCGAN 示例

环境：Python 3.7 + PyTorch 1.0.1

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

def init_weight(m):
    """DCGAN-style weight initialisation, meant to be passed to ``Module.apply``.

    Matching is done on the class *name*, so every layer whose name contains
    ``Conv`` (Conv2d, ConvTranspose2d, ...) has its weight redrawn from
    N(0, 0.02), and every layer whose name contains ``BatchNorm`` has its
    weight redrawn from N(1, 0.02) and its bias zeroed. Any other module is
    left untouched.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0., 0.02)
        return
    if 'BatchNorm' in layer_name:
        m.weight.data.normal_(1., 0.02)
        m.bias.data.fill_(0.)

class DCGenerator(nn.Module):
    """DCGAN generator assembled from a list of transposed-convolution specs.

    Each entry of ``convs`` is ``(out_channels, kernel_size, stride, padding)``
    and becomes a bias-free ``nn.ConvTranspose2d``. Every layer except the last
    is followed by ``BatchNorm2d`` + ``ReLU``; the last is followed by ``Tanh``
    so generated values lie in [-1, 1].

    Args:
        convs: layer specifications, applied in order.
        in_channels: channel count of the input tensor, i.e. the latent
            dimension when noise is shaped ``(batch, in_channels, 1, 1)``.
            Defaults to 1 so existing callers keep working; pass a larger
            value to use a higher-dimensional latent.
    """

    def __init__(self, convs, in_channels=1):
        super(DCGenerator, self).__init__()
        self.convs = nn.ModuleList()
        last = len(convs) - 1
        for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
            self.convs.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                                 stride, padding, bias=False))
            if i < last:
                # hidden layers: BN + ReLU
                self.convs.append(nn.BatchNorm2d(out_channels))
                self.convs.append(nn.ReLU())
            else:
                # output layer: Tanh to produce data in [-1, 1]
                self.convs.append(nn.Tanh())
            in_channels = out_channels
        self.apply(init_weight)

    def forward(self, input):
        """Run ``input`` through every layer in order and return the result."""
        out = input
        for module in self.convs:
            out = module(out)
        return out


class Discriminator(nn.Module):
    """DCGAN discriminator assembled from a list of convolution specs.

    Each entry of ``convs`` is ``(out_channels, kernel_size, stride, padding)``
    and becomes a bias-free ``nn.Conv2d``. Every layer except the last is
    followed by ``LeakyReLU(0.2)``; ``BatchNorm2d`` is inserted on those hidden
    layers too, except the first (DCGAN does not normalise D's input layer).
    ``forward`` flattens the final feature map and applies a sigmoid.

    Args:
        convs: layer specifications, applied in order.
        in_channels: channel count of the input images. Defaults to 1
            (grayscale) so existing callers keep working.
    """

    def __init__(self, convs, in_channels=1):
        super(Discriminator, self).__init__()
        self.convs = nn.ModuleList()
        last = len(convs) - 1
        for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
            self.convs.append(nn.Conv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, bias=False))
            if i != last:
                if i != 0:
                    # we do not use BN in the input layer of D
                    self.convs.append(nn.BatchNorm2d(out_channels))
                self.convs.append(nn.LeakyReLU(0.2))
            in_channels = out_channels
        self.apply(init_weight)

    def forward(self, input):
        """Return sigmoid scores shaped ``(batch, C*H*W)`` of the last conv's output."""
        out = input
        for layer in self.convs:
            out = layer(out)
        out = out.view(out.size(0), -1)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values)
        return torch.sigmoid(out)

train.py

from __future__ import print_function
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from model import *

def sample_noise(batch_size, channels):
    """Draw standard-normal noise shaped ``(batch_size, channels, 1, 1)`` as float32."""
    shape = (batch_size, channels, 1, 1)
    noise = torch.randn(*shape)
    return noise.float()

# Number of training epochs.
max_iter = 25

# Allow torchvision to fetch MNIST into ./ if it is missing.
download = True

# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], matching the range of the generator's Tanh output.
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize([0.5,], [0.5,])])

# NOTE: the dataset is created (and possibly downloaded) at import time.
mnist = datasets.MNIST('./', train=True, transform=trans, download=download)

batch_size = 64

# Use the GPU only when one is actually available, instead of crashing on the
# first .cuda() call on CPU-only machines.
use_cuda = torch.cuda.is_available()

if __name__ == '__main__':
    d_convs = [(32, 4, 2, 1), (64, 4, 2, 1), (1, 7, 1, 0)]
    discriminator = Discriminator(d_convs)
    g_convs = [(64, 7, 1, 0), (32, 4, 2, 1), (1, 4, 2, 1)]
    generator = DCGenerator(g_convs)
    print(discriminator)
    print(generator)

    # Single place that decides where every model/tensor lives.
    device = torch.device('cuda' if use_cuda else 'cpu')
    discriminator.to(device)
    generator.to(device)

    dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

    # Move models to the device *before* building the optimizers.
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    real_label, fake_label = 1., 0.

    criterion = nn.BCELoss()

    # Fixed noise used to render comparable sample grids after every epoch.
    # Channel count 1 matches DCGenerator's input as constructed above.
    fixed_noise = sample_noise(batch_size, 1).to(device)

    for epoch in range(1, max_iter + 1):
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)
            # The last batch of an epoch may be smaller than batch_size.
            n = x.size(0)

            # ---- train D on real data ----
            optimizer_d.zero_grad()
            output = discriminator(x)
            # Targets must have exactly the output's shape, (n, 1): BCELoss
            # warns (PyTorch 1.0) or raises (newer) on a (n,) target.
            loss_d_real = criterion(output, torch.full_like(output, real_label))
            loss_d_real.backward()
            d_x = output.mean().item()

            # ---- train D on fake data ----
            z = sample_noise(n, 1).to(device)
            fake = generator(z)
            # detach(): this step updates D only, no gradients into G.
            output = discriminator(fake.detach())
            loss_d_fake = criterion(output, torch.full_like(output, fake_label))
            loss_d_fake.backward()
            optimizer_d.step()

            err_d = loss_d_real.item() + loss_d_fake.item()

            # ---- train G: make D label the fakes as real ----
            optimizer_g.zero_grad()
            output = discriminator(fake)
            loss_g = criterion(output, torch.full_like(output, real_label))
            loss_g.backward()
            optimizer_g.step()
            err_g = loss_g.item()
            d_g_z = output.mean().item()

            print('[{:02d}/{:02d}],[{:03d}/{:03d}], errD: {:.4f}, D(x): {:.4f}, errG: {:.4f}, D(G(z)): {:.4f}'.format(
                  epoch, max_iter, i, len(dataloader), err_d, d_x, err_g, d_g_z))

        # Render samples without building an autograd graph (the old
        # `volatile=True` flag is ignored since PyTorch 0.4). The generator
        # stays in train mode, as before.
        with torch.no_grad():
            fake = generator(fixed_noise)

        save_image(fake, './mnist-fake-{:02d}.png'.format(epoch),
                   normalize=True)

你可能感兴趣的:(机器学习)