SAGAN Code Reading Notes

Paper: "Self-Attention Generative Adversarial Networks"
Paper link: https://arxiv.org/abs/1805.08318
Code: https://github.com/heykeetae/Self-Attention-GAN

These notes follow the execution flow of the code.

Default parameter settings

adv_loss         = 'hinge'
attn_path        = './attn'
batch_size       = 64
beta1            = 0.0
beta2            = 0.9
d_conv_dim       = 64
d_iters          = 5
d_lr             = 0.0004
dataset          = 'celeb'
g_conv_dim       = 64
g_lr             = 0.0001
g_num            = 5
image_path       = './data'
imsize           = 64
lambda_gp        = 10
log_path         = './logs'
log_step         = 10
lr_decay         = 0.95
model            = 'sagan'
model_save_path  = './models'
model_save_step  = 1.0
num_workers      = 2
parallel         = False
pretrained_model = None
sample_path      = './samples'
sample_step      = 100
total_step       = 1000000
train            = True
use_tensorboard  = False
version          = 'sagan_celeb'
z_dim            = 128
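
A minimal, illustrative argparse sketch that declares a few of the defaults above (this is not the repo's actual parameter-parsing file; only the names and values mirror the list):

import argparse

# Illustrative only: declare a handful of the defaults listed above.
parser = argparse.ArgumentParser()
parser.add_argument('--adv_loss', type=str, default='hinge', choices=['wgan-gp', 'hinge'])
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--imsize', type=int, default=64)
parser.add_argument('--z_dim', type=int, default=128)
parser.add_argument('--g_lr', type=float, default=0.0001)
parser.add_argument('--d_lr', type=float, default=0.0004)
parser.add_argument('--beta1', type=float, default=0.0)
parser.add_argument('--beta2', type=float, default=0.9)
config = parser.parse_args([])  # empty list -> fall back to the defaults
print(config.adv_loss, config.g_lr, config.d_lr)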

Discriminator architecture

The discriminator is configured with batch_size=64, image_size=64, conv_dim=64.

Assume the input tensor has shape torch.Size([64, 3, 64, 64]).

# layer1
Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The shape becomes torch.Size([64, 64, 32, 32]).
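
This follows the standard convolution output-size formula: with kernel 4, stride 2, and padding 1, the output size is floor((64 + 2*1 - 4) / 2) + 1 = 32, which is why each of these stride-2 layers halves the spatial resolution.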

# layer2
Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The shape becomes torch.Size([64, 128, 16, 16]).

# layer3
Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The shape becomes torch.Size([64, 256, 8, 8]).

As can be seen, the first three layers share the same structure (a spectrally normalized 4x4 stride-2 convolution followed by LeakyReLU): the channel count keeps growing while the spatial size shrinks.

After the first three layers comes a self-attention layer. It does not change the feature shape, which stays torch.Size([64, 256, 8, 8]); the attention map it produces has shape torch.Size([64, 64, 64]) (one 64x64 weight matrix per sample, since 8x8 = 64 spatial positions).
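
A minimal sketch of such a self-attention layer (query/key/value 1x1 convolutions plus a learned residual weight; the repo's Self_Attn class follows the same pattern, though variable names may differ):

import torch
import torch.nn as nn

class SelfAttn(nn.Module):
    """Sketch of SAGAN-style self-attention over a feature map."""
    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # residual weight, starts at 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, W, H = x.size()
        N = W * H
        q = self.query_conv(x).view(B, -1, N).permute(0, 2, 1)  # B x N x C//8
        k = self.key_conv(x).view(B, -1, N)                     # B x C//8 x N
        attention = self.softmax(torch.bmm(q, k))               # B x N x N
        v = self.value_conv(x).view(B, -1, N)                   # B x C x N
        out = torch.bmm(v, attention.permute(0, 2, 1)).view(B, C, W, H)
        return self.gamma * out + x, attention                  # residual connection

# Shape check matching the note: an 8x8 feature map gives a 64x64 attention map.
x = torch.randn(64, 256, 8, 8)
out, attn = SelfAttn(256)(x)
print(out.shape, attn.shape)  # torch.Size([64, 256, 8, 8]) torch.Size([64, 64, 64])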

If the input image size is 64, there is also a layer4 with the same structure as the first three layers.

# layer4
Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The shape becomes torch.Size([64, 512, 4, 4]).

After layer4 there is a second self-attention layer, which outputs the second attention map with shape torch.Size([64, 16, 16]).

# last
Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))

The shape becomes torch.Size([64, 1, 1, 1]).

Finally, squeeze() removes the singleton dimensions (every axis of size 1), so the discriminator output has shape torch.Size([64]).
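
A quick check of that squeeze (illustrative):

import torch

out = torch.randn(64, 1, 1, 1)  # shape of the last conv's output
print(out.squeeze().shape)      # torch.Size([64]): every size-1 dimension is dropped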

Generator architecture

The generator is configured with batch_size=64, image_size=64, z_dim=128, conv_dim=64.

First, a batch of random noise is drawn, with z_dim values per image; assume the input has shape torch.Size([64, 128]).

The input is first reshaped to torch.Size([64, 128, 1, 1]).
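
One way to perform this reshape (a view call of this form is what the generator's forward pass does; shown here only for illustration):

import torch

z = torch.randn(64, 128)                 # batch of z_dim-dimensional noise vectors
z = z.view(z.size(0), z.size(1), 1, 1)   # -> torch.Size([64, 128, 1, 1])
print(z.shape)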

repeat_num = int(np.log2(self.imsize)) - 3  # imsize=64: log2(64)=6, so repeat_num=3
mult = 2 ** repeat_num  # 2**3 = 8

This gives mult = 8, so layer1 outputs conv_dim * mult = 64 * 8 = 512 channels.

# layer1
ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1))
SpectralNorm()
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The shape becomes torch.Size([64, 512, 4, 4]).

# layer2
ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The shape becomes torch.Size([64, 256, 8, 8]).

# layer3
ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The shape becomes torch.Size([64, 128, 16, 16]).

After layer3, self-attention is applied; the first attention map map1 has shape torch.Size([64, 256, 256]).

# layer4
ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The shape becomes torch.Size([64, 64, 32, 32]).

After layer4 there is another attention layer; map2 has shape torch.Size([64, 1024, 1024]).

# last
ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
Tanh()

The shape becomes torch.Size([64, 3, 64, 64]).

Loss computation

Discriminator

The overall discriminator loss is
$L_D = -E_{(x,y)\sim p_{data}}[\min(0, -1 + D(x,y))] - E_{z \sim p_z,\, y \sim p_{data}}[\min(0, -1 - D(G(z),y))]$

  1. Real images as input

    d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
    
  2. Generated images as input (the two terms are summed into the full loss; see the sketch after this list)

    d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
    
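
Putting the two terms together, one discriminator hinge-loss step looks roughly like this (d_out_real and d_out_fake stand for D(x) and D(G(z)); here they are random tensors purely for illustration):

import torch
import torch.nn as nn

d_out_real = torch.randn(64)  # placeholder for D(x)
d_out_fake = torch.randn(64)  # placeholder for D(G(z))

d_loss_real = nn.ReLU()(1.0 - d_out_real).mean()  # = -E[min(0, -1 + D(x))]
d_loss_fake = nn.ReLU()(1.0 + d_out_fake).mean()  # = -E[min(0, -1 - D(G(z)))]
d_loss = d_loss_real + d_loss_fake
print(d_loss.item())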

Generator

The overall generator loss is
$L_G = -E_{z \sim p_z,\, y \sim p_{data}}[D(G(z), y)]$

fake_images, _, _ = self.G(z)            # the generator also returns its two attention maps
g_out_fake, _, _ = self.D(fake_images)   # batch x n
g_loss_fake = - g_out_fake.mean()

In other words, the generator's loss is the negative of the discriminator's mean score on the generated images.

Summary

  • Both the generator and the discriminator use two self-attention layers.

  • In the generator, a BatchNorm2d layer is added right after spectral normalization; I have not figured out why both are used here.

  • The two networks use different learning rates, but their updates alternate at a 1:1 ratio (see the optimizer sketch below).

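A sketch of the optimizer setup implied by the defaults above (g_lr=0.0001, d_lr=0.0004, beta1=0.0, beta2=0.9); G and D below are placeholder modules, not the repo's actual classes:

import torch
import torch.nn as nn

G = nn.Linear(128, 64 * 64 * 3)  # stands in for the real Generator
D = nn.Linear(64 * 64 * 3, 1)    # stands in for the real Discriminator

# Two time-scale update rule: D gets a larger learning rate than G, same betas.
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))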