Paper link: Selective Kernel Networks
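SKNet is mainly inspired by neuroscience: neurons in the visual cortex can dynamically adjust their RF (receptive field) in response to different stimuli. Many earlier networks already used several kernels of different sizes and fused the features they extracted. For example, the original Inception module uses $1 \times 1$, $3 \times 3$ and $5 \times 5$ kernels.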
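The paper proposes a nonlinear approach to aggregating the information extracted by multiple kernels. The authors introduce "Selective Kernel" (SK) convolution, which consists of three operations: Split, Fuse and Select.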
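Split: as shown in the figure above, two convolutions with different kernel sizes ($3 \times 3$ and $5 \times 5$ in the paper) are applied to $X$ separately, producing $\hat{U}$ and $\tilde{U}$, both of size $H \times W \times C$.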
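The Split step can be sketched in PyTorch as follows. This is a minimal illustration, not the paper's code: it assumes the two-branch setting with a $3 \times 3$ convolution and a $3 \times 3$ convolution with dilation 2 (the paper's cheaper substitute for a $5 \times 5$ kernel), each followed by BatchNorm and ReLU. The channel count, input size and variable names are arbitrary choices.

```python
import torch
from torch import nn

C = 64  # number of channels (illustrative)

# Branch 1: plain 3x3 convolution; padding=1 keeps H and W unchanged.
conv3x3 = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(C),
    nn.ReLU(inplace=True),
)

# Branch 2: 3x3 convolution with dilation 2, giving a 5x5 receptive field;
# padding=2 keeps H and W unchanged.
conv_dilated = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=3, padding=2, dilation=2, bias=False),
    nn.BatchNorm2d(C),
    nn.ReLU(inplace=True),
)

x = torch.rand(2, C, 16, 16)
u_hat = conv3x3(x)          # corresponds to \hat{U}
u_tilde = conv_dilated(x)   # corresponds to \tilde{U}
print(u_hat.shape, u_tilde.shape)  # both torch.Size([2, 64, 16, 16])
```

Because both branches preserve the spatial size and channel count, their outputs can be summed element-wise in the Fuse step.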
Fuse: the two branches are combined in three steps.
1. Add $\hat{U}$ and $\tilde{U}$ element-wise to obtain $U$ (size $H \times W \times C$):
$$U = \hat{U} + \tilde{U}$$
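2. Apply global average pooling to $U$ to obtain $\mathbf{s} \in \mathbb{R}^{C}$, which has the same number of channels as $U$, namely $C$. Its $c$-th element is
$$s_c = \frac{1}{H \times W}\sum_{i=1}^{H}\sum_{j=1}^{W} U_c(i, j)$$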
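3. Compress $\mathbf{s}$ with a fully connected layer to obtain $\mathbf{z} \in \mathbb{R}^{d}$, i.e. $\mathbf{z}$ has $d$ channels:
$$\mathbf{z} = \mathcal{F}_{fc}(\mathbf{s}) = \delta(\mathcal{B}(\mathbf{W}\mathbf{s}))$$
where $\mathbf{W} \in \mathbb{R}^{d \times C}$, $\delta$ is the ReLU function and $\mathcal{B}$ is Batch Normalization. The dimension $d$ is
$$d = \max(C/r, L)$$
where $r$ is the reduction ratio and $L$ is the minimum value of $d$, set to 32 in the paper.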
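The three Fuse steps can be sketched as below. This is an illustrative sketch with random tensors standing in for the branch outputs; the sizes and $r = 16$ are assumptions chosen to show that $d = \max(C/r, L)$ clamps to $L$ when $C/r$ is small. The `nn.Linear` layer holds $\mathbf{W}$ and has no bias because BatchNorm follows it.

```python
import torch
from torch import nn

N, C, H, W = 4, 64, 16, 16
r, L = 16, 32
d = max(C // r, L)  # d = max(C/r, L) = max(4, 32) = 32

# Stand-ins for the two branch outputs \hat{U} and \tilde{U}.
u_hat = torch.rand(N, C, H, W)
u_tilde = torch.rand(N, C, H, W)

# Step 1: element-wise sum.
u = u_hat + u_tilde

# Step 2: global average pooling over H and W, one value per channel.
s = u.mean(dim=(2, 3))  # shape (N, C)

# Step 3: z = ReLU(BN(W s)); the weight of nn.Linear has shape (d, C).
fc = nn.Sequential(
    nn.Linear(C, d, bias=False),
    nn.BatchNorm1d(d),
    nn.ReLU(inplace=True),
)
z = fc(s)  # shape (N, d)
print(d, s.shape, z.shape)  # 32 torch.Size([4, 64]) torch.Size([4, 32])
```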
Select: a softmax across the branches, computed separately for each channel, turns $\mathbf{z}$ into attention weights:
$$a_c = \frac{e^{\mathbf{A}_c \mathbf{z}}}{e^{\mathbf{A}_c \mathbf{z}} + e^{\mathbf{B}_c \mathbf{z}}}, \quad b_c = \frac{e^{\mathbf{B}_c \mathbf{z}}}{e^{\mathbf{A}_c \mathbf{z}} + e^{\mathbf{B}_c \mathbf{z}}}$$
where $\mathbf{A}, \mathbf{B} \in \mathbb{R}^{C \times d}$, $\mathbf{A}_c$ and $\mathbf{B}_c$ are their $c$-th rows, and $\mathbf{a}, \mathbf{b}$ are the soft attention vectors of $\hat{U}$ and $\tilde{U}$ respectively. Since this is a softmax over the branches, $a_c + b_c = 1$. The output $\mathbf{V}$ is obtained by weighting each channel of the two branches:
$$\mathbf{V}_c = a_c \cdot \hat{U}_c + b_c \cdot \tilde{U}_c, \quad \mathbf{V}_c \in \mathbb{R}^{H \times W}$$
$$\mathbf{V} = [\mathbf{V}_1, \mathbf{V}_2, \dots, \mathbf{V}_C], \quad \mathbf{V} \in \mathbb{R}^{H \times W \times C}$$
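The Select formulas map onto a softmax over a stacked "branch" axis. In this sketch, random matrices play the role of $\mathbf{A}$ and $\mathbf{B}$ (in a real network they are learned fully connected layers), and the small sizes are arbitrary. Stacking the logits $\mathbf{A}_c\mathbf{z}$ and $\mathbf{B}_c\mathbf{z}$ and applying `torch.softmax` along that axis reproduces $a_c$ and $b_c$ exactly.

```python
import torch

torch.manual_seed(0)
N, C, d, H, W = 2, 8, 4, 5, 5

z = torch.randn(N, d)            # compact descriptor from the Fuse step
A = torch.randn(C, d)            # row A_c is used for channel c of branch 1
B = torch.randn(C, d)            # row B_c is used for channel c of branch 2
u_hat = torch.randn(N, C, H, W)
u_tilde = torch.randn(N, C, H, W)

# Logits A_c z and B_c z for every channel, stacked on a branch axis.
logits = torch.stack([z @ A.T, z @ B.T], dim=1)  # shape (N, 2, C)
weights = torch.softmax(logits, dim=1)           # softmax across the 2 branches
a, b = weights[:, 0], weights[:, 1]              # each of shape (N, C)

# V_c = a_c * \hat{U}_c + b_c * \tilde{U}_c, broadcast over H and W.
v = a[:, :, None, None] * u_hat + b[:, :, None, None] * u_tilde
print(v.shape)  # torch.Size([2, 8, 5, 5])
```

With $M > 2$ branches the same code works by stacking $M$ logit tensors; the weights of each channel then sum to 1 across all $M$ branches.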
SKNet is built by stacking multiple SK units; each SK unit consists of a $1 \times 1$ convolution, an SK convolution, and another $1 \times 1$ convolution.
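A compact sketch of one SK unit follows. It assumes a ResNeXt-style bottleneck with an identity shortcut (SKNet is built on ResNeXt), and it uses a hypothetical two-branch `SKConv2` (3×3 and dilated 3×3, $r = 16$, $L = 32$) rather than the author's code; the class names and channel sizes are illustrative only.

```python
import torch
from torch import nn


class SKConv2(nn.Module):
    """Hypothetical compact two-branch SK convolution (3x3 and dilated 3x3)."""

    def __init__(self, channels, groups=32, r=16, L=32):
        super().__init__()
        d = max(channels // r, L)
        # Split: dilation 1 and 2 give 3x3 and 5x5 receptive fields.
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=dil, dilation=dil,
                          groups=groups, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
            )
            for dil in (1, 2)
        ])
        # Fuse: z = ReLU(BN(W s)).
        self.fc = nn.Sequential(
            nn.Linear(channels, d, bias=False),
            nn.BatchNorm1d(d),
            nn.ReLU(inplace=True),
        )
        # Select: one projection per branch (the A and B matrices).
        self.fcs = nn.ModuleList([nn.Linear(d, channels) for _ in range(2)])

    def forward(self, x):
        feats = torch.stack([br(x) for br in self.branches], dim=1)  # (N, 2, C, H, W)
        s = feats.sum(dim=1).mean(dim=(2, 3))                          # (N, C)
        z = self.fc(s)                                                  # (N, d)
        logits = torch.stack([fc(z) for fc in self.fcs], dim=1)        # (N, 2, C)
        attn = torch.softmax(logits, dim=1)[..., None, None]           # (N, 2, C, 1, 1)
        return (feats * attn).sum(dim=1)                                # (N, C, H, W)


class SKUnit(nn.Module):
    """Hypothetical SK unit: 1x1 conv -> SK conv -> 1x1 conv, plus identity shortcut."""

    def __init__(self, channels, mid_channels, groups=32):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, mid_channels, 1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            SKConv2(mid_channels, groups=groups),
            nn.Conv2d(mid_channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.body(x) + x)


unit = SKUnit(channels=256, mid_channels=128)
out = unit(torch.rand(2, 256, 14, 14))
print(out.shape)  # torch.Size([2, 256, 14, 14])
```

The identity shortcut only works when input and output shapes match; a downsampling unit would additionally need a strided SK convolution and a projection shortcut, which this sketch omits.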
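import torch
from torch import nn


class SKConv(nn.Module):
    def __init__(self, features, M, G, r, stride=1, L=32):
        """
        :param features: input channel dimensionality.
        :param M: the number of branches.
        :param G: number of convolution groups.
        :param r: the ratio for computing d, the length of z.
        :param stride: stride, default 1.
        :param L: the minimum dim of vector z in paper, default 32.
        """
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        # Split: branch i uses a (3 + 2i) x (3 + 2i) kernel; padding = 1 + i keeps H and W.
        # (The paper uses dilated 3x3 convolutions instead of the larger kernels.)
        self.convs = nn.ModuleList([])
        for i in range(self.M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False),
            ))
        self.gap = nn.AdaptiveAvgPool2d(1)
        # Fuse: z = ReLU(BN(W s)), compressing C channels down to d.
        self.fc = nn.Sequential(
            nn.Linear(features, d),
            nn.BatchNorm1d(d),
            nn.ReLU(inplace=False),
        )
        # Select: one fully connected layer per branch (A, B, ... in the paper).
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        # Softmax over the branch dimension.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Stack the branch outputs: feas has shape (N, M, C, H, W).
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
        # Fuse: U = sum over branches, s = GAP(U) with shape (N, C), z = fc(s) with shape (N, d).
        fea_U = torch.sum(feas, dim=1)
        fea_s = self.gap(fea_U).squeeze_(-1).squeeze_(-1)
        fea_z = self.fc(fea_s)
        # Select: attention logits of shape (N, M, C), normalized across branches.
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
        attention_vectors = self.softmax(attention_vectors)
        # Reshape to (N, M, C, 1, 1) so the weights broadcast over H and W.
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v


if __name__ == '__main__':
    x = torch.rand(8, 64, 32, 32)
    conv = SKConv(64, 3, 8, 2)
    out = conv(x)
    print(out.shape)  # torch.Size([8, 64, 32, 32])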
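In PyTorch, an in-place method XXX_ (such as unsqueeze_ and squeeze_ in the code above) computes the same result as its counterpart XXX; the only difference is that XXX_ modifies the tensor in place instead of returning a new tensor.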
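A small demonstration of the difference, using unsqueeze/unsqueeze_ and add/add_ as examples (the tensors are arbitrary):

```python
import torch

a = torch.zeros(2, 3)

# Out-of-place: returns a new tensor (a view sharing a's storage); a's shape is unchanged.
b = a.unsqueeze(0)
print(a.shape, b.shape)  # torch.Size([2, 3]) torch.Size([1, 2, 3])

# In-place: a itself gains the extra dimension.
a.unsqueeze_(0)
print(a.shape)  # torch.Size([1, 2, 3])

c = torch.ones(3)
d = c.add(1)  # c is still [1., 1., 1.]; d is a new tensor [2., 2., 2.]
c.add_(1)     # c is overwritten and becomes [2., 2., 2.]
print(c, d)
```

In-place calls avoid allocating a new tensor, which is why the SKConv code applies them to freshly computed intermediate results; applying them to a tensor that is still needed elsewhere would silently change that tensor too.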