本教程教你如何使用 DCGAN 训练 MNIST 数据集,生成手写数字。
Python:请参考官网安装。
PyTorch(含 torchvision):请参考官网安装。
pip install jupyter
pip install matplotlib
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
设置随机种子,以便复现实验结果。
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(0)

# --- Data loading ---
dataroot = "data/mnist"  # root directory handed to the MNIST dataset
workers = 10             # number of DataLoader worker processes
batch_size = 100         # samples per training batch
image_size = 64          # target size given to the Resize transform
nc = 1                   # channels per training image (1 here; 3 for color)

# --- Model dimensions ---
nz = 100        # length of the latent vector z (generator input)
ngf = ndf = 64  # base feature-map width of the generator / discriminator

# --- Optimisation ---
num_epochs = 10  # number of training epochs
lr = 0.0002      # learning rate for both Adam optimizers
beta1 = 0.5      # beta1 hyperparameter for both Adam optimizers

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
使用 MNIST 数据集,其中训练集 6 万张,测试集 1 万张。我们这里不是分类任务,而是使用 DCGAN 的生成任务,所以就不区分训练集和测试集了,全部图像都可以利用。
# Build both MNIST splits with the same preprocessing pipeline:
# resize, convert to tensor, then normalize with mean 0.5 and std 0.5.
# The training split is constructed first and is the only one that
# requests a download; the test split keeps the default (no download).
train_data, test_data = (
    datasets.MNIST(
        root=dataroot,
        train=is_train,
        transform=transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
        download=is_train,
    )
    for is_train in (True, False)
)

# This is a generation task, so both splits are merged into one dataset.
dataset = train_data + test_data
print(f'Total Size of Dataset: {len(dataset)}')
输出:
Total Size of Dataset: 70000
# Shuffled mini-batch loader over the merged dataset.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
检测 CUDA 是否可用,可用就用 CUDA 加速,否则使用 CPU 训练。
# Train on the first CUDA device when one is available and GPUs were
# requested (ngpu > 0); otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device('cuda:0' if use_cuda else 'cpu')

# Preview one batch of training images as a 10-per-row grid.
first_batch = next(iter(dataloader))
plt.figure(figsize=(10, 10))
plt.title("Training Images")
plt.axis('off')
# `* 0.5 + 0.5` undoes Normalize((0.5,), (0.5,)) before display.
inputs = utils.make_grid(first_batch[0][:100] * 0.5 + 0.5, nrow=10)
# make_grid returns channels-first; imshow expects channels-last.
plt.imshow(inputs.permute(1, 2, 0))
在 DCGAN 论文中,作者指出所有模型权重应当从均值为 0、标准差为 0.02 的正态分布中随机初始化。
def weights_init(m):
    """Initialise one module in place using the DCGAN scheme.

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    submodule. Modules are matched by class name:

    * names containing ``'Conv'``: weight drawn from N(0, 0.02);
    * names containing ``'BatchNorm'``: weight drawn from N(1, 0.02) and
      bias set to 0.

    Any other module is left untouched. A matching module whose ``weight``
    or ``bias`` is missing or ``None`` (e.g. ``BatchNorm2d(affine=False)``)
    is skipped for that parameter instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # Parameters may be absent or None, so never dereference them blindly.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
class Generator(nn.Module):
    """DCGAN generator built from a stack of transposed convolutions.

    Four ConvTranspose2d -> BatchNorm2d -> ReLU stages are followed by a
    final ConvTranspose2d and Tanh. Layer widths come from the module-level
    ``nz``, ``ngf`` and ``nc`` settings. For an ``(nz) x 1 x 1`` input the
    output is ``(nc) x 64 x 64``.

    Args:
        ngpu: Number of GPUs; stored on the instance as ``self.ngpu``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) per upsampling stage;
        # every stage uses a 4x4 kernel and no bias.
        stages = [
            (nz, ngf * 8, 1, 0),       # Z           -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # 4 x 4       -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # 8 x 8       -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      # 16 x 16     -> (ngf) x 32 x 32
        ]
        layers = []
        for in_ch, out_ch, stride, padding in stages:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)
# Create the generator and move it to the selected device
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if device.type == 'cuda' and ngpu > 1:
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02. This call was previously commented out, which left
# the generator on PyTorch's default initialisation.
netG.apply(weights_init)
class Discriminator(nn.Module):
    """DCGAN discriminator built from a stack of strided convolutions.

    An initial Conv2d -> LeakyReLU stage is followed by three
    Conv2d -> BatchNorm2d -> LeakyReLU stages that each double the channel
    count, then a final Conv2d and Sigmoid. Layer widths come from the
    module-level ``nc`` and ``ndf`` settings. For an ``(nc) x 64 x 64``
    input the output is ``1 x 1 x 1`` per sample.

    Args:
        ngpu: Number of GPUs; stored on the instance as ``self.ngpu``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Input stage (no batch norm): (nc) x 64 x 64 -> (ndf) x 32 x 32
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, doubling the width each time:
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        width = ndf
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width *= 2
        # Output stage: (ndf*8) x 4 x 4 -> (1) x 1 x 1, squashed by Sigmoid.
        layers += [
            nn.Conv2d(width, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)
# Build the discriminator, then place it on the selected device.
netD = Discriminator(ngpu)
netD = netD.to(device)

# Wrap in DataParallel only when running on CUDA with more than one GPU.
multi_gpu = device.type == 'cuda' and ngpu > 1
if multi_gpu:
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Randomly initialize all weights with weights_init (mean=0, stdev=0.02).
netD.apply(weights_init)
# Binary cross-entropy loss, shared by both networks.
criterion = nn.BCELoss()

# A fixed batch of 100 latent vectors, reused to visualize how the
# generator's output progresses during training.
fixed_noise = torch.randn(100, nz, 1, 1, device=device)

# Label convention used during training.
real_label, fake_label = 1., 0.

# One Adam optimizer per network, with identical hyperparameters.
optimizerD, optimizerG = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (netD, netG)
)
# Training Loop
# Lists to keep track of progress
img_list = []   # image grid of G(fixed_noise), appended once per epoch
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss, one entry per iteration
D_x_list = []   # mean D(x) on the real batch, one entry per iteration
D_z_list = []   # mean D(G(z)) after the D update, one entry per iteration
# Lowest generator loss seen so far, kept as a plain float. Starting at +inf
# guarantees that 'model.pt' is written at least once.
loss_tep = float('inf')

# Loop-invariant values used to format the progress line.
num_steps = len(dataloader)
epoch_width = len(str(num_epochs))
step_width = len(str(num_steps))

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    beg_time = time.time()
    # For each batch in the dataloader
    for i, data in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach so no gradient reaches G here
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch (accumulated with the real-batch gradients)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Total discriminator loss, used for reporting only
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Convert the losses to Python floats once and reuse them below
        errD_value = errD.item()
        errG_value = errG.item()

        # Output training stats; '\r' keeps rewriting the same console line
        end_time = time.time()
        run_time = round(end_time - beg_time)
        print(
            f'Epoch: [{epoch+1:0>{epoch_width}}/{num_epochs}]',
            f'Step: [{i+1:0>{step_width}}/{num_steps}]',
            f'Loss-D: {errD_value:.4f}',
            f'Loss-G: {errG_value:.4f}',
            f'D(x): {D_x:.4f}',
            f'D(G(z)): [{D_G_z1:.4f}/{D_G_z2:.4f}]',
            f'Time: {run_time}s',
            end='\r'
        )

        # Save Losses for plotting later
        G_losses.append(errG_value)
        D_losses.append(errD_value)

        # Save D(X) and D(G(z)) for plotting later
        D_x_list.append(D_x)
        D_z_list.append(D_G_z2)

        # Save the Best Model (lowest generator loss so far)
        if errG_value < loss_tep:
            # Unwrap DataParallel so the saved keys carry no 'module.' prefix
            # and can be loaded into a plain Generator instance.
            bare_netG = netG.module if isinstance(netG, nn.DataParallel) else netG
            torch.save(bare_netG.state_dict(), 'model.pt')
            loss_tep = errG_value

    # Check how the generator is doing by saving G's output on fixed_noise
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    img_list.append(utils.make_grid(fake * 0.5 + 0.5, nrow=10))

    # Move past the '\r'-rewritten progress line before the next epoch
    print()
输出:
Starting Training Loop...
Epoch: [01/10] Step: [700/700] Loss-D: 0.6744 Loss-G: 1.0026 D(x): 0.5789 D(G(z)): [0.0638/0.4204] Time: 114s
Epoch: [02/10] Step: [700/700] Loss-D: 2.2584 Loss-G: 3.7674 D(x): 0.9661 D(G(z)): [0.8334/0.0440] Time: 166s
Epoch: [03/10] Step: [700/700] Loss-D: 1.2438 Loss-G: 0.8505 D(x): 0.4126 D(G(z)): [0.1107/0.4717] Time: 166s
Epoch: [04/10] Step: [700/700] Loss-D: 0.3479 Loss-G: 2.5261 D(x): 0.8771 D(G(z)): [0.1796/0.0980] Time: 166s
Epoch: [05/10] Step: [700/700] Loss-D: 0.6771 Loss-G: 3.8938 D(x): 0.9139 D(G(z)): [0.3889/0.0277] Time: 161s
Epoch: [06/10] Step: [700/700] Loss-D: 0.2697 Loss-G: 3.8211 D(x): 0.9490 D(G(z)): [0.1823/0.0282] Time: 166s
Epoch: [07/10] Step: [700/700] Loss-D: 0.2874 Loss-G: 2.1176 D(x): 0.8062 D(G(z)): [0.0494/0.1503] Time: 180s
Epoch: [08/10] Step: [700/700] Loss-D: 0.7798 Loss-G: 1.6315 D(x): 0.5978 D(G(z)): [0.1508/0.2463] Time: 171s
Epoch: [09/10] Step: [700/700] Loss-D: 0.3052 Loss-G: 0.8984 D(x): 0.7611 D(G(z)): [0.0023/0.4799] Time: 165s
Epoch: [10/10] Step: [700/700] Loss-D: 1.1115 Loss-G: 2.2473 D(x): 0.7334 D(G(z)): [0.4824/0.1385] Time: 157s
# Plot the generator and discriminator losses, sampled every 100th iteration.
# Explicit x values are supplied so the axis shows real iteration numbers
# (0, 100, 200, ...) rather than the index of the subsampled point.
plt.figure(figsize=(20, 10))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(range(0, len(G_losses), 100), G_losses[::100], label="G")
plt.plot(range(0, len(D_losses), 100), D_losses[::100], label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.axhline(y=0, label="0", c="g")  # asymptote
plt.legend()
# Plot the mean discriminator outputs on real and generated batches,
# sampled every 100th iteration. Explicit x values are supplied so the axis
# shows real iteration numbers rather than the index of the subsampled point.
plt.figure(figsize=(20, 10))
plt.title("D(x) and D(G(z)) During Training")
plt.plot(range(0, len(D_x_list), 100), D_x_list[::100], label="D(x)")
plt.plot(range(0, len(D_z_list), 100), D_z_list[::100], label="D(G(z))")
plt.xlabel("iterations")
plt.ylabel("Probability")
plt.axhline(y=0.5, label="0.5", c="g")  # asymptote
plt.legend()
# Animate the saved generator snapshots, one frame per entry in img_list.
fig = plt.figure(figsize=(10, 10))
plt.axis("off")

ims = []
for grid in img_list:
    # Grids are channels-first; imshow expects channels-last.
    frame = plt.imshow(grid.permute(1, 2, 0), animated=True)
    # ArtistAnimation takes a list of artists for each frame.
    ims.append([frame])

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
# Take one batch of real images and tile the first 100 into a 10-per-row grid.
real = next(iter(dataloader))
real_grid = utils.make_grid(real[0][:100] * 0.5 + 0.5, nrow=10)

# Rebuild the generator on the CPU and restore the saved best weights.
netG = Generator(0)
best_state = torch.load('model.pt', map_location=torch.device('cpu'))
netG.load_state_dict(best_state)
netG.eval()

# Generate fake images from the fixed latent vectors without tracking gradients.
with torch.no_grad():
    fake = netG(fixed_noise.cpu())
fake = utils.make_grid(fake * 0.5 + 0.5, nrow=10)

# Draw the real grid on the left and the generated grid on the right.
plt.figure(figsize=(20, 10))
panels = (("Real Images", real_grid), ("Fake Images", fake))
for position, (panel_title, grid) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.axis("off")
    plt.title(panel_title)
    plt.imshow(grid.permute(1, 2, 0))

# Save the comparison result
plt.savefig('result.jpg', bbox_inches='tight')
https://github.com/XavierJiezou/pytorch-dcgan-mnist
DCGAN 论文:https://arxiv.org/abs/1511.06434(原始 GAN 论文:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)
https://blog.csdn.net/qq_42951560/article/details/110308336
代码运行报错或其它问题可以联系我的微信号:wxhghgxj