
Notes on the PyTorch DCGAN Implementation

This post annotates some of the key implementation details of the DCGAN (Deep Convolutional Generative Adversarial Networks) example from the official PyTorch Examples.

First, a few notes on the script's command-line arguments:

--dataset specifies the training dataset

--dataroot path where the dataset is downloaded or already stored

--workers number of worker processes the DataLoader uses for preprocessing and loading data

--batchSize number of images fed to the model per batch

--imageSize size to which the original images are resampled before entering the model

--nz size of the latent z vector (the initial noise vector)

--ngf base number of feature maps in the generator (the number of feature maps doubles as the layers go deeper)

--ndf base number of feature maps in the discriminator (the number of feature maps doubles as the layers go deeper)

--niter number of training epochs

--lr initial learning rate

--beta1 value of the beta1 parameter of the Adam optimizer

--cuda train on the GPU

--netG checkpoint file for the generator (saved generator weights)

--netD checkpoint file for the discriminator (saved discriminator weights)

--outf folder where output images and checkpoint files are saved

--manualSeed seed for random number generation
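As a concrete example, a typical invocation for training on CIFAR10 might look like the following (the file name main.py is an assumption here; use whatever name the script is saved under locally):

python main.py --dataset cifar10 --dataroot ./data --cuda --niter 25 --outf ./results --manualSeed 42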

Below are some of the details that I personally think are most important.

Data preprocessing

Take CIFAR10 as an example: the original images are 32×32. During data loading, transforms.Compose() chains a series of preprocessing transforms. transforms.Resize() resamples the image to the specified size; transforms.ToTensor() converts a PIL Image or numpy.ndarray to a tensor and rescales the value range from [0, 255] to [0.0, 1.0]; finally, transforms.Normalize() maps the image into the range [-1, 1].

The first argument of transforms.Normalize() is the mean and the second is the standard deviation; each must have one entry per image channel. transforms.Normalize computes: input[channel] = (input[channel] - mean[channel]) / std[channel].

These transforms are evaluated during each DataLoader iteration; they are not applied once up front to the underlying Dataset.
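A minimal sketch of this pipeline (the 0.5 mean/std values match what the script below passes for CIFAR10; the local ./data path is just an example):

import torchvision.transforms as transforms
import torchvision.datasets as dset

# resample to 64x64, convert to a [0.0, 1.0] tensor, then map to [-1, 1]
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = dset.CIFAR10(root='./data', download=True, transform=transform)
img, label = dataset[0]          # the transform runs here, not when the Dataset is built
print(img.shape)                 # torch.Size([3, 64, 64])
print(img.min(), img.max())      # roughly -1.0 and 1.0, since (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]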

The generator network

The generator mainly uses 5 transposed-convolution (deconvolution) layers to expand the original noise vector into a 64×64 image. The kernel size is 4×4, and the number of feature maps changes layer by layer as ngf*8 -> ngf*4 -> ngf*2 -> ngf -> nc.

What deserves attention here is how stride and padding are chosen: except for the first transposed-convolution layer, whose stride is 1, every layer uses a stride of 2, because each of the remaining layers doubles the spatial size of its feature map.

Once the stride is fixed, the padding can be computed from the output-size formula of nn.ConvTranspose2d (see the official documentation).

Denote the input size as (N, C_in, H_in, W_in) and the output size as (N, C_out, H_out, W_out). The formulas are:

H_out = (H_in - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]

W_out = (W_in - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]
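Plugging the generator's actual layer parameters into this formula (kernel_size=4 and output_padding=0 throughout) reproduces the size progression noted in the code comments below. The helper function here is just an illustration, not part of the script:

def convtranspose2d_out(h_in, stride, padding, kernel_size=4, output_padding=0):
    # output-size formula of nn.ConvTranspose2d, for square inputs
    return (h_in - 1) * stride - 2 * padding + kernel_size + output_padding

size = 1                                   # the latent z vector is treated as a 1x1 "image"
for stride, padding in [(1, 0), (2, 1), (2, 1), (2, 1), (2, 1)]:
    size = convtranspose2d_out(size, stride, padding)
    print(size)                            # prints 4, 8, 16, 32, 64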

The training procedure

Training is split into updating the discriminator and updating the generator.

For the discriminator update, a batch of batchSize real images is first fed into the discriminator, followed by backpropagation and a weight update; then a batch of batchSize fake images produced by the generator is fed in, again followed by backpropagation and a weight update. One detail worth noting is the statement output = netD(fake.detach()): calling detach() on the fake tensor keeps the generator frozen while the discriminator is being trained, so no operations are recorded for autograd on the generator's side and no gradients flow back into it.

When training the generator, the fake tensor produced above is passed through the discriminator again, this time without detach(), followed by backpropagation and a generator weight update.
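A minimal sketch of the contrast (variable names follow the script below): the fake batch is detached for the discriminator step but not for the generator step, so gradients reach netG only in the second step.

# --- discriminator step: netG receives no gradient ---
fake = netG(noise)
label.fill_(fake_label)
errD_fake = criterion(netD(fake.detach()), label)
errD_fake.backward()                      # gradients land in netD's parameters only

# --- generator step: gradients flow through netD back into netG ---
label.fill_(real_label)                   # fake labels are "real" for the generator cost
errG = criterion(netD(fake), label)
errG.backward()                           # gradients now reach netG's parameters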

One more point concerns the image-saving check (the if i % 100 == 0: branch in the code below): in my opinion it should be changed to if (i + 1) % 100 == 0:, so that the outputs of the 100th, 200th, ... iterations are saved. The original code saves the outputs of the 1st, 101st, ... iterations.
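The suggested change, applied to the saving block near the end of the inner training loop below:

if (i + 1) % 100 == 0:        # save after the 100th, 200th, ... iteration of each epoch
    vutils.save_image(real_cpu,
                      '%s/real_samples.png' % opt.outf,
                      normalize=True)
    fake = netG(fixed_noise)
    vutils.save_image(fake.detach(),
                      '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                      normalize=True)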

The full source code follows:

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

opt = parser.parse_args()
print(opt)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
elif opt.dataset == 'lsun':
    dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0" if opt.cuda else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)

netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
