[论文笔记] Self-Attention Generative Adversarial Networks

问题:卷积只有局部的感受野,大范围的依赖关系只能通过多层卷积进行处理。这可能影响网络学习到长依赖关系:1、小模型可能无法学习;2、优化算法可能很难找到多层卷积的合适参数来捕捉这种依赖关系;3、这种参数化可能对之前没见过的图片很不稳定,容易失败。
单纯增大卷积核扩大感受野是个办法,但增大了计算量


文章贡献
1、SAGAN 中引入:引入attention机制学习long range dependency
[论文笔记] Self-Attention Generative Adversarial Networks_第1张图片
2、生成器、鉴别器均引入 spectral normalization提高训练稳定性
[论文笔记] Self-Attention Generative Adversarial Networks_第2张图片
[论文笔记] Self-Attention Generative Adversarial Networks_第3张图片
3、在ImageNet上进行实验,生成128X128大小的图像,达到SOTA指标,IS为52.52,FID为18.65 。(基于50k随机生成样本计算)
[论文笔记] Self-Attention Generative Adversarial Networks_第4张图片


实验细节
1、128X128大小Imagenet训练集,Adam优化器,hinge损失函数,迭代一百万次取最优指标的模型,指标是在50k随机产生的样本上进行计算得出的。
2、文中使用的稳定GAN训练的技巧:

  • 生成器、鉴别器均引入 spectral normalization,G、D在一次迭代中均优化一次
  • TTUR学习率更新方式
    [论文笔记] Self-Attention Generative Adversarial Networks_第5张图片

attention map 分析
源码分析

class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) #
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N)

        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)
        # print(self.gamma)
        out = self.gamma*out + x
        return out,attention

你可能感兴趣的:(论文笔记,GAN)