pytorch-GAN

A minimal vanilla GAN on MNIST in PyTorch: a fully connected generator and discriminator trained against each other with binary cross-entropy loss.

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets


class generator(nn.Module):
    """Maps a 100-d latent vector to a flattened 28x28 image through an MLP."""

    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(100, 128)
        self.fc11 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.out = nn.Linear(1024, 784)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2, inplace=True)
        x = F.leaky_relu(self.fc11(x), 0.2, inplace=True)
        x = F.leaky_relu(self.fc2(x), 0.2, inplace=True)
        x = F.leaky_relu(self.fc3(x), 0.2, inplace=True)
        x = torch.tanh(self.out(x))  # outputs in [-1, 1] to match the normalized images
        return x


class discriminator(nn.Module):
    """Maps a flattened 28x28 image to a single real/fake probability."""

    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 1)

    def forward(self, input):
        x = input.view(input.shape[0], -1)
        x = F.leaky_relu(self.fc1(x), 0.2, inplace=True)
        x = F.leaky_relu(self.fc2(x), 0.2, inplace=True)
        x = torch.sigmoid(self.out(x))
        return x

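# --- Quick shape sanity check (an illustrative sketch, not part of the original post).
# The generator takes an (N, 100) noise batch and returns (N, 784) values in [-1, 1];
# the discriminator flattens its input and returns an (N, 1) probability. ---
_g_check, _d_check = generator(), discriminator()
_z_check = torch.randn(4, 100)
_fake_check = _g_check(_z_check)
assert _fake_check.shape == (4, 784)
assert _d_check(_fake_check).shape == (4, 1)
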

gen = generator()
dis = discriminator()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
gen.to(device)
dis.to(device)
loss = nn.BCELoss()  # binary cross-entropy loss
optimizer_G = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Configure data loader (MNIST is single-channel, so normalize with a single mean/std)
os.makedirs('D:/mnist/', exist_ok=True)
os.makedirs('D:/mnist/images/', exist_ok=True)  # output directory for the sample grids saved below
dataloader = DataLoader(
    datasets.MNIST('D:/mnist/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,))
                   ])), batch_size=64, shuffle=True)

print("finish load dataset")
for epoch in range(200):
    for i, (img, _) in enumerate(dataloader):
        batch = img.size(0)

        # Ground-truth labels: 1 for real images, 0 for generated (fake) images
        valid = torch.ones(batch, 1, device=device)
        fake = torch.zeros(batch, 1, device=device)
        real_imgs = img.to(device)  # (64, 1, 28, 28)

        # train G: push D's output on generated images towards "real"
        optimizer_G.zero_grad()
        z = torch.tensor(np.random.normal(0, 1, (batch, 100)), dtype=torch.float, device=device)  # latent noise
        gen_imgs = gen(z)
        g_loss = loss(dis(gen_imgs), valid)  # generated images are meant to fool the discriminator
        g_loss.backward()
        optimizer_G.step()

        # train D: real images towards 1, generated images (detached) towards 0
        optimizer_D.zero_grad()
        real_loss = loss(dis(real_imgs), valid)
        fake_loss = loss(dis(gen_imgs.detach()), fake)  # detach so D's update does not backprop into G
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        if i % 300 == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, 200, i, len(dataloader),
                                                                             d_loss.item(), g_loss.item()))
        batch_done = epoch * len(dataloader) + i
        if batch_done % 400 == 0:
            save_image(gen_imgs.data[:25].view(25, 1, 28, 28), 'D:/mnist/images/%d.png' % batch_done, nrow=5,
                       normalize=True)
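
# --- After training: save the generator and draw one final sample grid (an illustrative
# sketch, not part of the original post; the output file names below are assumptions). ---
torch.save(gen.state_dict(), 'D:/mnist/generator.pth')
with torch.no_grad():
    z_sample = torch.randn(25, 100, device=device)
    samples = gen(z_sample).view(25, 1, 28, 28)
    save_image(samples, 'D:/mnist/images/final_samples.png', nrow=5, normalize=True)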
