class UnetGenerator(nn.Module):
    """Create a Unet-based generator."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator.

        Parameters:
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output images
            num_downs (int)    -- the number of downsamplings in UNet. For example, if num_downs == 7,
                                  an image of size 128x128 becomes 1x1 at the bottleneck.
                                  At least 5 levels are always built (innermost + three fixed
                                  channel-reducing blocks + outermost), so values below 5 behave like 5.
            ngf (int)          -- the number of filters in the outermost downsampling conv layer;
                                  deeper levels use ngf * 2, ngf * 4 and ngf * 8
            norm_layer         -- normalization layer
            use_dropout (bool) -- forwarded only to the (num_downs - 5) intermediate ngf * 8 blocks

        The U-Net is constructed from the innermost layer to the outermost layer:
        each new UnetSkipConnectionBlock wraps the previously built block as its
        submodule, so the final structure is a recursive nesting of blocks.
        """
        super(UnetGenerator, self).__init__()
        # Innermost (bottleneck) block: ngf * 8 -> ngf * 8, nothing nested inside it
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # Intermediate blocks that keep ngf * 8 filters (none are added when num_downs <= 5)
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        # Outermost block: takes input_nc channels in and produces output_nc channels
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward: run the whole nested U-Net on ``input``."""
        return self.model(input)
Unet的模型结构如下图所示,因此代码是从最内层开始搭建的:
经过第一行后,网络结构如下,也就是最内层的下采样->上采样。
之后有一个循环,经过第一次循环后,在上一层的外围再次搭建了下采样和上采样:
经过第二次循环:
经过第三次循环:
可以看到,除最内层外,每次反卷积的输入特征图的channel都是1024。这是因为它除了要接收上一层反卷积的输出(512维),还要接收与其特征图大小相同的下采样层的输出(512维),拼接后即为1024维。最内层的反卷积没有跳跃连接,因此其输入仍是512维。
循环完毕后,再向外添加四次降采样和反卷积(最后一次为最外层),最终的网络结构如下(对应 num_downs=8、ngf=64、input_nc=output_nc=3):
UnetGenerator(
(model): UnetSkipConnectionBlock(
(model): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): Tanh()
)
)
)
我们这里假定pix2pix是风格A2B,风格A就是左边的图,风格B是右边的图。
反向传播的代码如下,整个过程是先更新D再更新G。
(1)首先向前传播,输入A,经过G,得到fakeB;
(2)开始更新D,进入backward_D函数:
(3)开始更新G,进入backward_G函数:
下图就可视化了上述的过程。
def backward_D(self):
    """Compute the discriminator loss and backpropagate it.

    The conditional discriminator is shown the input image concatenated
    (along dim 1, the channel axis) with either the generated image or the
    ground-truth image. The generated pair is detached so that no gradient
    reaches the generator during this step.

    Sets self.loss_D_fake, self.loss_D_real and self.loss_D.
    """
    # Fake pair (A, G(A)): detach to stop backprop into the generator
    pair_fake = torch.cat((self.real_A, self.fake_B), 1)
    self.loss_D_fake = self.criterionGAN(self.netD(pair_fake.detach()), False)

    # Real pair (A, B)
    pair_real = torch.cat((self.real_A, self.real_B), 1)
    self.loss_D_real = self.criterionGAN(self.netD(pair_real), True)

    # Average the two terms, then compute gradients
    self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real)
    self.loss_D.backward()
def backward_G(self):
    """Compute the generator loss and backpropagate it.

    Two terms are combined:
      * an adversarial term: D should classify the pair (A, G(A)) as real;
      * an L1 reconstruction term between G(A) and B, scaled by opt.lambda_L1.

    Sets self.loss_G_GAN, self.loss_G_L1 and self.loss_G.
    """
    # Adversarial term: no detach here, gradients must flow through D into G
    conditioned_fake = torch.cat((self.real_A, self.fake_B), 1)
    self.loss_G_GAN = self.criterionGAN(self.netD(conditioned_fake), True)

    # Reconstruction term: push G(A) towards the ground truth B
    self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

    # Total objective, then compute gradients
    self.loss_G = self.loss_G_GAN + self.loss_G_L1
    self.loss_G.backward()
def optimize_parameters(self):
    """Run one training iteration: forward pass, then a D step, then a G step.

    The generator output is computed once and reused by both steps.
    """
    self.forward()  # compute the fake images G(A)

    # ---- Discriminator step: D must accumulate gradients ----
    self.set_requires_grad(self.netD, True)
    opt_d = self.optimizer_D
    opt_d.zero_grad()
    self.backward_D()
    opt_d.step()

    # ---- Generator step: freeze D so only G's gradients are computed ----
    self.set_requires_grad(self.netD, False)
    opt_g = self.optimizer_G
    opt_g.zero_grad()
    self.backward_G()
    opt_g.step()
pix2pix还对判别器的结构做了一定的改动。之前都是对整张图像输出一个是否为真实的概率。pix2pix提出了PatchGan的概念。PatchGAN对图片中的每一个N×N的小块(patch)计算概率,然后再将这些概率求平均值作为整体的输出。
在上面的代码中pred_fake = self.netD(fake_AB.detach())的输出就不是一个概率值,而是30×30的特征图,相当于有30×30个patch。
下图表示标准的D网络结构(n_layers = 3),其中 n_layers 表示主要的特征提取卷积层数,这里为3。如何理解?
需要注意,PatchGAN 中 channel 维度最大为 512(即 ndf × 8,由代码中的 min(2 ** n, 8) 限制)。
DataParallel(
(module): NLayerDiscriminator(
(model): Sequential(
(0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
)
)
)
具体代码如下,与我们前面所述稍有不同。按照前面的描述,for n in range(1, n_layers) 应相当于构建 n_layers 个特征提取层;但代码中该循环实际上只构建了 n_layers-1 个(stride=2),最后一个标准的特征提取层(stride=1)单独放在了其后的 sequence += [...] 中。
但理解上仍然可以按照前面的描述。SPADE 框架重新实现了构建 PatchGAN 的过程,把最后一个标准的特征提取层也放进了 for n in range(1, n_layers) 循环中构建。见https://github.com/NVlabs/SPADE/blob/master/models/networks/discriminator.py
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: maps an image to a 1-channel grid of per-patch scores."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of normalized feature-extraction conv layers:
                               n_layers - 1 of them use stride 2 and the last one uses stride 1.
                               The network additionally has an un-normalized stride-2 head conv
                               and a final 1-channel output conv.
            norm_layer      -- normalization layer (a class, or a functools.partial wrapping one)

        Filter counts double from layer to layer but are capped at ndf * 8.
        """
        super(NLayerDiscriminator, self).__init__()
        # Conv bias is redundant before BatchNorm2d (it has its own affine shift);
        # only InstanceNorm2d is treated as needing the conv bias.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4    # convolution kernel size
        padw = 1  # convolution padding
        # Head conv: stride 2, no normalization
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters (stride 2)
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # Last feature layer uses stride 1: with kw=4, padw=1 it shrinks H and W by 1
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Output conv: a single channel of per-patch scores
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: return the per-patch score map for ``input``."""
        return self.model(input)
下面这张图是CGAN的示意图。可以看到,
pix2pix模型与CGAN最大的不同在于不再输入噪声z。因为在实验中,即便给G输入一个噪声z,G也只会学着将其忽略并生成图像,噪声z对输出结果的影响微乎其微。因此为了简洁,将z去掉了。
pix2pix模型中G的输入实际上等于CGAN模型的条件y。