NeurIPS-2021 workshop
Attention mechanisms have been one of the hottest research directions in computer vision in recent years.
We aim to utilize the contributing factors of weights for the improvement of attention mechanisms.
However, these works neglect the information carried by the weights tuned during training.
The paper proposes the Normalization-based Attention Module (NAM) and verifies its effectiveness on ResNet and MobileNet.
A NAM module is embedded at the end of each network block; a sketch of one possible way to wire this up is given after the code below.
The computation of $W_{\gamma}$ and $W_{\lambda}$ is shown in Figure 1 of the paper.
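From my reading of the paper, the two attention maps reuse the normalization scaling factors roughly as follows (my own reconstruction, not copied verbatim from the paper):

$$M_c = \mathrm{sigmoid}\big(W_{\gamma} \cdot \mathrm{BN}(F_1)\big), \qquad W_{\gamma} = \frac{\gamma_i}{\sum_j \gamma_j}$$

$$M_s = \mathrm{sigmoid}\big(W_{\lambda} \cdot \mathrm{BN}_s(F_2)\big), \qquad W_{\lambda} = \frac{\lambda_i}{\sum_j \lambda_j}$$

where $\gamma$ are the BN scaling factors over channels and $\lambda$ are the corresponding scaling factors of the pixel normalization.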
The authors also constrain $\gamma$ and $\lambda$ with a regularization term added to the training loss.
$p$ is the penalty factor that balances $g(\gamma)$ and $g(\lambda)$.
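For reference, my reconstruction of the regularized training objective, with $g(\cdot)$ denoting the $l_1$-norm penalty used to encourage sparsity:

$$\mathrm{Loss} = \sum_{(x,y)} l\big(f(x, W),\, y\big) + p \sum g(\gamma) + p \sum g(\lambda)$$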
Let's take a look at the authors' open-source code: https://github.com/Christian-lyc/NAM
import torch
import torch.nn as nn
from torch.nn import functional as F


class Channel_Att(nn.Module):
    def __init__(self, channels, t=16):  # t is accepted but unused in the released code
        super(Channel_Att, self).__init__()
        self.channels = channels
        # The BN layer whose affine scaling factors (gamma) provide the channel weights.
        self.bn2 = nn.BatchNorm2d(self.channels, affine=True)

    def forward(self, x):
        residual = x
        x = self.bn2(x)
        # Normalize the absolute BN scaling factors so they sum to 1 across channels.
        weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())
        # Move channels to the last dimension so the per-channel weights broadcast.
        x = x.permute(0, 2, 3, 1).contiguous()
        x = torch.mul(weight_bn, x)
        x = x.permute(0, 3, 1, 2).contiguous()
        # Sigmoid gate, applied multiplicatively to the original input.
        x = torch.sigmoid(x) * residual
        return x


class Att(nn.Module):
    def __init__(self, channels, shape, out_channels=None, no_spatial=True):
        super(Att, self).__init__()
        # Only the channel attention is wired up; shape, out_channels and no_spatial are unused.
        self.Channel_Att = Channel_Att(channels)

    def forward(self, x):
        x_out1 = self.Channel_Att(x)
        return x_out1
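A quick sanity check of the module above; the input shape is arbitrary, and since `shape` is unused by the channel-only code I simply pass None:

# Shape check with an arbitrary input; eval() makes BN use its running statistics.
x = torch.randn(2, 64, 32, 32)        # N, C, H, W
att = Att(channels=64, shape=None)
att.eval()
y = att(x)
print(y.shape)                        # torch.Size([2, 64, 32, 32]) -- same shape as the input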
Only the channel normalization-based attention part is included in the released code.
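To illustrate the "embedded at the end of each network block" statement from earlier, here is a minimal sketch of one way to attach the module to a simplified residual block. The block structure and the exact placement (before the residual add) are my assumptions, not something spelled out in the paper or the released code:

# Hypothetical integration sketch (not from the authors' repo): a simplified
# residual block with NAM attached at its end.
class BasicBlockWithNAM(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.nam = Att(channels, shape=None)   # NAM at the end of the block

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.nam(out)                    # assumed placement: before the residual add
        return F.relu(out + x)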
The experiments report top-1 and top-5 accuracy on CIFAR-100 and ImageNet.
Adding only the channel NAM works better than adding only the spatial one.
The improvement is not especially large; the selling point is that almost no extra parameters are introduced. Let's look at the parameter count in detail below.
In the paper the count is multiplied by 4; going only by the released code, it should be multiplied by 2, i.e., just the parameters of the BN layer (one γ and one β per channel).
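A quick check of the "times 2" reading against the released code (64 channels chosen arbitrarily):

# Count the parameters of the channel attention for an arbitrary channel width.
att = Channel_Att(channels=64)
n_params = sum(p.numel() for p in att.parameters())
print(n_params)   # 128 = 2 * 64: the BN weight (gamma) and bias (beta), one of each per channel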
The paper is quite short, so some details remain unknown, e.g., the concrete implementation of pixel normalization.