import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
class Generator(nn.Module):
    def __init__(self, image_size=320, latent_dim=100, output_channel=1):
        """
        image_size: image width and height
        latent_dim: the dimension of the random noise z
        output_channel: number of channels in the generated image,
            e.g. 1 for a grayscale image, 3 for an RGB image
        """
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.output_channel = output_channel
        self.image_size = image_size
        # MLP that maps the latent vector to a flattened image; Tanh keeps
        # outputs in [-1, 1], matching the Normalize transform on real images.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, output_channel * image_size * image_size),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        # Reshape the flat output to (batch, channels, height, width).
        img = img.view(img.size(0), self.output_channel, self.image_size, self.image_size)
        return img
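# A minimal shape sanity check for the Generator (a hypothetical helper, not
# part of the training pipeline): it feeds a small batch of noise through the
# model and checks the output shape. eval() is used so BatchNorm relies on
# running statistics rather than tiny-batch statistics.
def _sanity_check_generator():
    g = Generator(image_size=32, latent_dim=100, output_channel=1)
    g.eval()
    with torch.no_grad():
        img = g(torch.rand(4, 100))
    assert img.shape == (4, 1, 32, 32)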
class Discriminator(nn.Module):
    def __init__(self, image_size=320, input_channel=1):
        """
        image_size: image width and height
        input_channel: number of channels in the input image,
            e.g. 1 for a grayscale image, 3 for an RGB image
        """
        super(Discriminator, self).__init__()
        self.image_size = image_size
        self.input_channel = input_channel
        # MLP that maps a flattened image to a single probability that the
        # image is real (Sigmoid output in (0, 1), paired with BCE loss).
        self.model = nn.Sequential(
            nn.Linear(input_channel * image_size * image_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten (batch, C, H, W) to (batch, C*H*W) before the MLP.
        img_flat = img.view(img.size(0), -1)
        out = self.model(img_flat)
        return out
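# A companion sanity check for the Discriminator (hypothetical helper): each
# input image should map to one probability strictly inside (0, 1).
def _sanity_check_discriminator():
    d = Discriminator(image_size=32, input_channel=1)
    with torch.no_grad():
        out = d(torch.randn(4, 1, 32, 32))
    assert out.shape == (4, 1)
    assert ((out > 0) & (out < 1)).all()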
def load_mnist_data():
    """
    Load the MNIST(0, 1, 2) dataset from an ImageFolder directory.
    """
    transform = transforms.Compose([
        transforms.Grayscale(1),
        # Resize to an exact 320x320 square to match the models' image_size;
        # Resize(320) alone would only fix the shorter side.
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        # Scale pixel values to [-1, 1] to match the generator's Tanh output.
        transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])
    train_dataset = ImageFolder(root='/pic', transform=transform)
    return train_dataset
def load_furniture_data():
    """
    Load the furniture dataset from an ImageFolder directory.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))
    ])
    train_dataset = ImageFolder(root='/pic', transform=transform)
    return train_dataset
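# ImageFolder expects one sub-directory per class under the root, e.g.
# /pic/0/xxx.png and /pic/1/yyy.png; the class labels are ignored by the GAN,
# but the directory layout is required for the dataset to load. Because the
# inputs are normalized to [-1, 1], a small denormalization helper (a sketch,
# not part of the original pipeline) is handy when visualizing samples:
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1] for display or saving."""
    return x.mul(0.5).add(0.5).clamp(0, 1)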
def train(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device, z_dim):
    """
    Train a GAN with models G and D for one epoch.
    Args:
        trainloader: data loader to train on
        G: Generator model
        D: Discriminator model
        G_optimizer: optimizer for G (e.g. Adam, SGD)
        D_optimizer: optimizer for D (e.g. Adam, SGD)
        loss_func: loss function for G and D, e.g. binary cross-entropy (BCE)
        device: cpu or cuda device
        z_dim: the dimension of the random noise z
    """
    D.train()
    G.train()
    D_total_loss = 0
    G_total_loss = 0
    for i, (x, _) in enumerate(trainloader):
        # Target labels: 1 for real images, 0 for generated (fake) images.
        y_real = torch.ones(x.size(0), 1).to(device)
        y_fake = torch.zeros(x.size(0), 1).to(device)
        x = x.to(device)
        # Uniform noise in [0, 1) as the generator input.
        z = torch.rand(x.size(0), z_dim).to(device)

        # ---- Discriminator step: push D(x) toward 1 and D(G(z)) toward 0 ----
        D_optimizer.zero_grad()
        d_real = D(x)
        d_real_loss = loss_func(d_real, y_real)
        g_z = G(z)
        # detach() stops gradients from flowing into G during D's update.
        d_fake = D(g_z.detach())
        d_fake_loss = loss_func(d_fake, y_fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        D_optimizer.step()

        # ---- Generator step: push D(G(z)) toward 1 ----
        G_optimizer.zero_grad()
        g_z = G(z)
        d_fake = D(g_z)
        # Non-saturating loss: label fakes as real for the generator update.
        g_loss = loss_func(d_fake, y_real)
        g_loss.backward()
        G_optimizer.step()

        D_total_loss += d_loss.item()
        G_total_loss += g_loss.item()
    return D_total_loss / len(trainloader), G_total_loss / len(trainloader)
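# Optional sampling helper (a sketch, not called by the original loop): saves
# a fixed-size grid of generated images so training progress can be inspected
# after each epoch. eval()/no_grad() keep BatchNorm statistics untouched.
def sample_images(G, z_dim, device, path, n_samples=9, nrow=3):
    G.eval()
    with torch.no_grad():
        z = torch.rand(n_samples, z_dim).to(device)
        imgs = G(z)
    save_image(imgs, path, nrow=nrow, normalize=True)
    G.train()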
def run_gan(trainloader, G, D, G_optimizer, D_optimizer, loss_func, n_epochs, device, latent_dim):
    d_loss_hist = []
    g_loss_hist = []
    for epoch in range(n_epochs):
        d_loss, g_loss = train(trainloader, G, D, G_optimizer, D_optimizer,
                               loss_func, device, z_dim=latent_dim)
        print('Epoch {}: Train D loss: {:.4f}, G loss: {:.4f}'.format(epoch, d_loss, g_loss))
        d_loss_hist.append(d_loss)
        g_loss_hist.append(g_loss)
    return d_loss_hist, g_loss_hist
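# The loss histories returned by run_gan are useful for diagnosing training
# (for instance, a D loss collapsing toward 0 usually means the discriminator
# is overpowering the generator). A minimal plotting sketch, assuming
# matplotlib is available:
def plot_loss_history(d_loss_hist, g_loss_hist):
    import matplotlib.pyplot as plt
    plt.plot(d_loss_hist, label='D loss')
    plt.plot(g_loss_hist, label='G loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()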
latent_dim = 100
image_size = 320
image_channel = 1
learning_rate = 0.0002
betas = (0.5, 0.999)
n_epochs = 100
batch_size = 9
device = torch.device('cpu')
train_dataset = load_mnist_data()
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
bceloss = nn.BCELoss().to(device)
G = Generator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel).to(device)
D = Discriminator(image_size=image_size, input_channel=image_channel).to(device)
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
d_loss_hist, g_loss_hist = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,
n_epochs, device, latent_dim)
# Sample two images from the trained generator; eval()/no_grad() avoid
# tracking gradients and tiny-batch BatchNorm statistics at inference time.
G.eval()
with torch.no_grad():
    test_input = torch.rand(2, latent_dim).to(device)
    prediction = G(test_input)
print(prediction.shape)  # torch.Size([2, 1, 320, 320])
save_image(prediction, "/pic/1.png", nrow=2, normalize=True)
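# To reuse the trained models later without retraining, the weights can be
# checkpointed and restored (a sketch; the file names are hypothetical):
# torch.save(G.state_dict(), 'generator.pth')
# G.load_state_dict(torch.load('generator.pth'))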