$$\text{self-attention} \;\overset{1\times 1\ \text{convolution}}{\leftrightarrows}\; \text{convolution}$$
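Item | Decomposition stage 1 | Decomposition stage 2 |
---|---|---|
Standard convolution with a $k \times k$ kernel | $k^2$ separate 1×1 convolutions | followed by shift and summation operations |
Self-attention | the query, key, and value projections, interpreted as multiple 1×1 convolutions | followed by computing the attention weights and aggregating the values |
Relation between the two | similar in both modules; this stage accounts for more of the computation | |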
Kernel | Input | Output |
---|---|---|
$K \in \mathbb{R}^{C_{out} \times C_{in} \times k \times k}$ | $F \in \mathbb{R}^{C_{in} \times H \times W}$ | $G \in \mathbb{R}^{C_{out} \times H \times W}$ |
Per-pixel representation (pixel at position $(i, j)$) | $f_{ij} \in \mathbb{R}^{C_{in}}$ | $g_{ij} \in \mathbb{R}^{C_{out}}$ |
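Note: this does not look like a plain unpadded convolution, because the feature map does not shrink after the convolution. The output keeps the input's spatial size $H \times W$, so the formulation implicitly assumes stride 1 with "same" padding.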
The standard convolution is:

$$g_{ij} = \sum_{p,q} K_{p,q}\, f_{i+p-\lfloor k/2 \rfloor,\; j+q-\lfloor k/2 \rfloor} \tag{1}$$

where $K_{p,q} \in \mathbb{R}^{C_{out} \times C_{in}}$, $p, q \in \{0, 1, \dots, k-1\}$, is the kernel weight at kernel position $(p, q)$.

Equation (1) can be rewritten as Equations (2) and (3), that is, as a sum of the feature maps produced by the individual kernel positions:

$$g_{ij} = \sum_{p,q} g_{ij}^{(p,q)} \tag{2}$$

$$g_{ij}^{(p,q)} = K_{p,q}\, f_{i+p-\lfloor k/2 \rfloor,\; j+q-\lfloor k/2 \rfloor} \tag{3}$$

Define the shift operation $\widetilde f \triangleq \mathrm{Shift}(f, \Delta x, \Delta y)$ as:

$$\widetilde f_{i,j} = f_{i+\Delta x,\; j+\Delta y}, \quad \forall i, j \tag{4}$$

Each term of the convolution can then be written as:

$$g_{ij}^{(p,q)} = K_{p,q}\, f_{i+p-\lfloor k/2 \rfloor,\; j+q-\lfloor k/2 \rfloor} = \mathrm{Shift}\big(K_{p,q} f_{ij},\; p-\lfloor k/2 \rfloor,\; q-\lfloor k/2 \rfloor\big) \tag{5}$$
Standard convolution: | |
---|---|
Stage 1 | The input feature map is linearly projected with the kernel weight of a given kernel position $(p, q)$; this is a standard 1×1 convolution. |
Stage 2 | The projected feature maps are shifted according to their kernel positions and summed. |

Stage I:

$$\widetilde g_{ij}^{(p,q)} = K_{p,q}\, f_{ij} \tag{6}$$

Stage II:

$$g_{ij}^{(p,q)} = \mathrm{Shift}\big(\widetilde g_{ij}^{(p,q)},\; p-\lfloor k/2 \rfloor,\; q-\lfloor k/2 \rfloor\big) \tag{7}$$

$$g_{ij} = \sum_{p,q} g_{ij}^{(p,q)} \tag{8}$$
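The two-stage form can be checked numerically against Eq. (1). The following is a minimal sketch in plain Python, assuming zero padding at the borders and a feature map stored as nested lists `f[i][j][channel]`; the helper names (`matvec`, `pixel`, `conv_direct`, `conv_two_stage`) are made up for this illustration and are not taken from the paper or its code.

```python
import random

def matvec(M, v):
    """Multiply a (C_out x C_in) matrix by a length-C_in vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def pixel(feat, i, j):
    """Feature vector at (i, j), or zeros outside the map (zero padding, an assumption)."""
    H, W = len(feat), len(feat[0])
    if 0 <= i < H and 0 <= j < W:
        return feat[i][j]
    return [0.0] * len(feat[0][0])

def conv_direct(K, f, k):
    """Eq. (1): g_ij = sum_{p,q} K[p][q] @ f[i+p-k//2][j+q-k//2]."""
    H, W, c_out = len(f), len(f[0]), len(K[0][0])
    g = [[[0.0] * c_out for _ in range(W)] for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for p in range(k):
                for q in range(k):
                    term = matvec(K[p][q], pixel(f, i + p - k // 2, j + q - k // 2))
                    g[i][j] = [a + b for a, b in zip(g[i][j], term)]
    return g

def conv_two_stage(K, f, k):
    """Stage I: k*k separate 1x1 convolutions. Stage II: shift each result and sum."""
    H, W, c_out = len(f), len(f[0]), len(K[0][0])
    g = [[[0.0] * c_out for _ in range(W)] for _ in range(H)]
    for p in range(k):
        for q in range(k):
            # Stage I: project every pixel with K[p][q] (a 1x1 convolution).
            g_tilde = [[matvec(K[p][q], f[i][j]) for j in range(W)] for i in range(H)]
            # Stage II: shift by (p - k//2, q - k//2) and accumulate.
            for i in range(H):
                for j in range(W):
                    term = pixel(g_tilde, i + p - k // 2, j + q - k // 2)
                    g[i][j] = [a + b for a, b in zip(g[i][j], term)]
    return g

random.seed(0)
k, c_in, c_out, H, W = 3, 2, 3, 5, 4
f = [[[random.uniform(-1, 1) for _ in range(c_in)] for _ in range(W)] for _ in range(H)]
K = [[[[random.uniform(-1, 1) for _ in range(c_in)] for _ in range(c_out)]
      for _ in range(k)] for _ in range(k)]

g1 = conv_direct(K, f, k)
g2 = conv_two_stage(K, f, k)
max_diff = max(abs(a - b)
               for row1, row2 in zip(g1, g2)
               for v1, v2 in zip(row1, row2)
               for a, b in zip(v1, v2))
print("max difference between the two forms:", max_diff)
```

The two functions do the same multiplications; only the order of "project" and "look up the neighbor" differs, which is the point of the decomposition.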
Representation | Input | Output |
---|---|---|
Whole feature map | $F \in \mathbb{R}^{C_{in} \times H \times W}$ | $G \in \mathbb{R}^{C_{out} \times H \times W}$ |
Per-pixel representation (pixel at position $(i, j)$) | $f_{ij} \in \mathbb{R}^{C_{in}}$ | $g_{ij} \in \mathbb{R}^{C_{out}}$ |

$$g_{ij} = \overset{N}{\underset{l=1}{\Big\Vert}} \Big( \sum_{a,b \in \mathcal{N}_k(i,j)} A\big(W_q^{(l)} f_{ij},\, W_k^{(l)} f_{ab}\big)\, W_v^{(l)} f_{ab} \Big) \tag{9}$$

Symbol | Meaning |
---|---|
$\Vert$ | concatenation of the outputs of the $N$ attention heads |
$\mathcal{N}_k(i, j)$ | the local region of pixels with spatial extent $k$ centered at $(i, j)$ |
$A\big(W_q^{(l)} f_{ij},\, W_k^{(l)} f_{ab}\big)$ | the attention weight over the features in $\mathcal{N}_k(i, j)$ |
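A commonly used way of computing $A$ (where $d$ is the dimension of $W_q^{(l)} f_{ij}$) is the scaled dot-product softmax over the neighborhood:

$$A\big(W_q^{(l)} f_{ij},\, W_k^{(l)} f_{ab}\big) = \underset{\mathcal{N}_k(i,j)}{\mathrm{softmax}} \left( \frac{\big(W_q^{(l)} f_{ij}\big)^{\top} \big(W_k^{(l)} f_{ab}\big)}{\sqrt{d}} \right) \tag{10}$$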
Stage I:

$$q_{ij}^{(l)} = W_q^{(l)} f_{ij}, \quad k_{ij}^{(l)} = W_k^{(l)} f_{ij}, \quad v_{ij}^{(l)} = W_v^{(l)} f_{ij} \tag{11}$$

Stage II:

$$g_{ij} = \overset{N}{\underset{l=1}{\Big\Vert}} \Big( \sum_{a,b \in \mathcal{N}_k(i,j)} A\big(q_{ij}^{(l)},\, k_{ab}^{(l)}\big)\, v_{ab}^{(l)} \Big) \tag{12}$$
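The two stages of self-attention can be sketched in the same style. This is a minimal plain-Python illustration under several simplifying assumptions that are not from the source: a single head ($N = 1$, so there is no concatenation), no positional encoding, neighborhoods clipped at the image border instead of padded, and the scaled dot-product softmax as $A$. The function names are hypothetical.

```python
import math
import random

def matvec(M, v):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def local_self_attention(f, Wq, Wk, Wv, k):
    """Single-head local self-attention over k x k neighborhoods (clipped at borders)."""
    H, W = len(f), len(f[0])
    # Stage I: three 1x1 convolutions, i.e. per-pixel linear projections.
    q = [[matvec(Wq, f[i][j]) for j in range(W)] for i in range(H)]
    key = [[matvec(Wk, f[i][j]) for j in range(W)] for i in range(H)]
    val = [[matvec(Wv, f[i][j]) for j in range(W)] for i in range(H)]
    d = len(Wq)   # dimension of the projected query
    r = k // 2
    g = [[None] * W for _ in range(H)]
    # Stage II: attention weights over N_k(i, j), then aggregation of the values.
    for i in range(H):
        for j in range(W):
            nbrs = [(a, b)
                    for a in range(i - r, i + r + 1)
                    for b in range(j - r, j + r + 1)
                    if 0 <= a < H and 0 <= b < W]
            logits = [sum(x * y for x, y in zip(q[i][j], key[a][b])) / math.sqrt(d)
                      for a, b in nbrs]
            top = max(logits)                      # subtract the max for stability
            exps = [math.exp(x - top) for x in logits]
            total = sum(exps)
            weights = [e / total for e in exps]    # softmax over the neighborhood
            g[i][j] = [sum(w * val[a][b][c] for w, (a, b) in zip(weights, nbrs))
                       for c in range(len(Wv))]
    return g

def rand_mat(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
c_in, d, H, W = 3, 2, 4, 4
Wq, Wk, Wv = rand_mat(d, c_in), rand_mat(d, c_in), rand_mat(d, c_in)
f = [[[random.uniform(-1, 1) for _ in range(c_in)] for _ in range(W)] for _ in range(H)]

g = local_self_attention(f, Wq, Wk, Wv, k=3)
g_k1 = local_self_attention(f, Wq, Wk, Wv, k=1)   # window of one pixel
print(len(g), len(g[0]), len(g[0][0]))
```

With `k=1` the neighborhood contains only the pixel itself, so the single attention weight is 1 and the output reduces to the value projection, which is a convenient sanity check.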
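Combined: | |
---|---|
Stage 1 | The input features are projected by three 1×1 convolutions, and each projection is reshaped into $N$ pieces. This yields a rich set of intermediate features containing $3 \times N$ feature maps. |
Stage 2 | In the second stage, the projected feature maps are shifted according to the kernel positions and finally aggregated. |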
$$F_{out} = \alpha F_{att} + \beta F_{conv} \tag{13}$$
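Shifting tensors in different directions breaks data locality and is hard to vectorize, which can significantly hurt the module's practical efficiency at inference time. The paper therefore replaces the inefficient tensor shifts with depthwise convolutions with fixed kernels, as shown in the figure:
If the convolution kernel (kernel size $k = 3$) is written as (15), the corresponding output can be expressed as (16) and (17).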
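The expressions referred to here are not reproduced in this text. As a sketch derived from Eq. (4) and the kernel indexing of Eq. (1), taking the shift toward the top-left as the example direction (an assumption, as is the exact form of each expression; the numbers simply follow the references above), they can be written as follows. Here $c$ indexes the channel and $f^{(dwc)}$ denotes the output of the depthwise convolution; both are notation introduced for this sketch.

$$K_c = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{bmatrix} \tag{15}$$

$$f^{(dwc)}_{c,i,j} = \sum_{p,q \in \{0,1,2\}} K_{c,p,q}\, f_{c,\; i+p-\lfloor k/2 \rfloor,\; j+q-\lfloor k/2 \rfloor} \tag{16}$$

$$f^{(dwc)}_{c,i,j} = f_{c,\; i-1,\; j-1} = \widetilde f_{c,i,j}, \qquad \widetilde f = \mathrm{Shift}(f, -1, -1) \tag{17}$$

Only the term with $p = q = 0$ survives in the sum, because every other kernel entry is zero.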
As in parts (a) to (b) of the figure, the shift is expressed by a convolution.
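That a fixed one-hot kernel reproduces a shift can be verified directly. The sketch below uses plain Python on a single channel with zero padding at the borders (an assumption), and the function names `shift` and `depthwise_conv3x3` are made up for this illustration. It checks all nine positions of a 3×3 kernel against `Shift` as defined in Eq. (4).

```python
def shift(f, dx, dy):
    """Eq. (4): out[i][j] = f[i + dx][j + dy], with zeros outside the map."""
    H, W = len(f), len(f[0])
    return [[f[i + dx][j + dy] if 0 <= i + dx < H and 0 <= j + dy < W else 0.0
             for j in range(W)]
            for i in range(H)]

def depthwise_conv3x3(f, kernel):
    """3x3 convolution of one channel with one kernel, stride 1, zero padding 1."""
    H, W = len(f), len(f[0])
    out = [[0.0] * W for _ in range(H)]
    for i in range(H):
        for j in range(W):
            for p in range(3):
                for q in range(3):
                    a, b = i + p - 1, j + q - 1
                    if 0 <= a < H and 0 <= b < W:
                        out[i][j] += kernel[p][q] * f[a][b]
    return out

# A small single-channel feature map with distinct entries.
f = [[float(4 * i + j) for j in range(4)] for i in range(4)]

all_match = True
for p in range(3):
    for q in range(3):
        # Kernel that is 1 at position (p, q) and 0 elsewhere.
        kernel = [[1.0 if (a, b) == (p, q) else 0.0 for b in range(3)] for a in range(3)]
        same = depthwise_conv3x3(f, kernel) == shift(f, p - 1, q - 1)
        all_match = all_match and same

print("one-hot depthwise conv equals shift for all 9 positions:", all_match)
```

Because the kernel is an ordinary convolution weight, it can run through the optimized convolution path and can later be made learnable, which a raw tensor shift cannot.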
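Part (c) | Explanation |
---|---|
I | To further handle the summation of features from different directions, all input features and all convolution kernels are concatenated respectively, so that the shift operations are expressed as a single group convolution. |
II | The convolution kernels are learnable weights, with the shift kernels used as their initialization. |
III | Multiple groups of convolution kernels are used so that the output channel dimension of the convolution path matches that of the self-attention path. |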
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=2)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(288, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(288, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(288, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(288, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=2)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(576, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(576, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(576, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(576, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(576, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(576, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 128, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=2)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(1152, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 128, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(1152, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 128, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(1152, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1000, bias=True)
)
Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): ACmix(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv_p): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
(pad_att): ReflectionPad2d((3, 3, 3, 3))
(unfold): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
(softmax): Softmax(dim=1)
(fc): Conv2d(12, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dep_conv): Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
import torch
import torch.nn as nn


# Helper functions used by ACmix below. They are not shown in the original post;
# the definitions here are assumed to match the official ACmix implementation.
def init_rate_half(tensor):
    # Initialize a learnable mixing rate to 0.5.
    if tensor is not None:
        tensor.data.fill_(0.5)


def init_rate_0(tensor):
    # Zero-fill a tensor in place. This returns None, so the assignment
    # `self.dep_conv.bias = init_rate_0(...)` below ends up setting the bias to None,
    # which is consistent with `bias=False` in the printed model.
    if tensor is not None:
        tensor.data.fill_(0.)


def stride(x, stride):
    # Spatially subsample a (B, C, H, W) tensor.
    return x[:, :, ::stride, ::stride]


class ACmix(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super(ACmix, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation
        # Learnable weights of the attention path and the convolution path.
        self.rate1 = torch.nn.Parameter(torch.Tensor(1))
        self.rate2 = torch.nn.Parameter(torch.Tensor(1))
        self.head_dim = self.out_planes // self.head

        # Stage I: three shared 1x1 convolutions (q, k, v projections).
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
        self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)
        self.softmax = torch.nn.Softmax(dim=1)

        # Convolution path: a light fc layer followed by a grouped "shift" convolution.
        self.fc = nn.Conv2d(3*self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
        self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes, kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1, stride=stride)

        self.reset_parameters()

    def reset_parameters(self):
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        # Initialize dep_conv with one-hot shift kernels: kernel i is 1 at position
        # (i // kernel_conv, i % kernel_conv) and 0 elsewhere.
        kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
        for i in range(self.kernel_conv * self.kernel_conv):
            kernel[i, i//self.kernel_conv, i%self.kernel_conv] = 1.
        kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def forward(self, x):
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** -0.5
        b, c, h, w = q.shape
        h_out, w_out = h//self.stride, w//self.stride

        # Attention path
        # Positional encoding
        pe = self.conv_p(position(h, w, x.is_cuda))

        q_att = q.view(b*self.head, self.head_dim, h, w) * scaling
        k_att = k.view(b*self.head, self.head_dim, h, w)
        v_att = v.view(b*self.head, self.head_dim, h, w)

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        unfold_k = self.unfold(self.pad_att(k_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out)  # b*head, head_dim, k_att^2, h_out, w_out
        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out)  # 1, head_dim, k_att^2, h_out, w_out

        att = (q_att.unsqueeze(2)*(unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)  # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out)
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)

        # Convolution path
        f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h*w), k.view(b, self.head, self.head_dim, h*w), v.view(b, self.head, self.head_dim, h*w)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
        out_conv = self.dep_conv(f_conv)

        # Weighted sum of the two paths: F_out = alpha * F_att + beta * F_conv
        return self.rate1 * out_att + self.rate2 * out_conv


# Helper used for the positional encoding: a (1, 2, H, W) grid of x/y coordinates in [-1, 1].
def position(H, W, is_cuda=True):
    if is_cuda:
        loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)
        loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)
    else:
        loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)
        loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)
    loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)
    return loc
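
The tensor shapes inside `forward` follow from the constructor arguments and the input size, so they can be worked out without running the model. The following is a minimal sketch in plain Python (no PyTorch required). The function name `acmix_shapes` and the dictionary keys are made up for this illustration, dilation is assumed to be 1, and the sliding-window count uses the output-size formula documented for `nn.Unfold`.

```python
def acmix_shapes(b, c, h, w, head=4, kernel_att=7, kernel_conv=3, stride=1):
    """Shapes of the main tensors in ACmix.forward for a (b, c, h, w) input with c == out_planes."""
    dilation = 1                                        # assumed; the module's default
    head_dim = c // head
    pad = (dilation * (kernel_att - 1) + 1) // 2        # same formula as padding_att
    h_out, w_out = h // stride, w // stride
    h_pad, w_pad = h + 2 * pad, w + 2 * pad             # after ReflectionPad2d(pad)
    # Number of sliding windows per axis for Unfold(kernel_size, padding=0, stride).
    n_h = (h_pad - kernel_att) // stride + 1
    n_w = (w_pad - kernel_att) // stride + 1
    return {
        "scaling": head_dim ** -0.5,
        "pe": (1, head_dim, h, w),
        "qkv_att": (b * head, head_dim, h, w),
        "padded": (b * head, head_dim, h_pad, w_pad),
        "unfolded": (b * head, head_dim * kernel_att ** 2, n_h * n_w),
        "unfold_k": (b * head, head_dim, kernel_att ** 2, h_out, w_out),
        "unfold_rpe": (1, head_dim, kernel_att ** 2, h_out, w_out),
        "att": (b * head, kernel_att ** 2, h_out, w_out),
        "out_att": (b, c, h_out, w_out),
        "f_all": (b, kernel_conv ** 2, head_dim, h * w),
        "f_conv": (b, kernel_conv ** 2 * head_dim, h, w),
        "out": (b, c, h_out, w_out),
    }

# First ACmix block of layer1: batch of 2, 64 channels, 56 x 56 feature map.
shapes = acmix_shapes(2, 64, 56, 56)
for name, shape in shapes.items():
    print(f"{name:>10}: {shape}")
```

For the reshape from `unfolded` to `unfold_k` to be valid, the number of sliding windows must equal `h_out * w_out`, which holds here for both `stride=1` and `stride=2`.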
Tensor shapes through the ACmix module in Layer 1 (batch size 2, head = 4, head_dim = 16):

Tensor | How it is obtained | Shape |
---|---|---|
x | input | [2, 64, 56, 56] |
q, k, v | one 1×1 conv each | [2, 64, 56, 56] |
scaling | float(self.head_dim) ** -0.5, the scaling factor in attention | 0.25 |
pe | self.conv_p(position(h, w, x.is_cuda)); the positional encoding is obtained by convolving linearly spaced coordinates | [1, 16, 56, 56] |
q_att, k_att, v_att | view(b*self.head, self.head_dim, h, w) | [8, 16, 56, 56] |
padded k_att | pad | [8, 16, 62, 62] |
unfolded k_att | Unfold(kernel_size=7, dilation=1, padding=0, stride=1) | [8, 784, 3136] |
unfold_k | view to (b*head, head_dim, k_att^2, h_out, w_out) | [8, 16, 49, 56, 56] |
unfold_rpe | pe with the same unfold processing as k_att | [1, 16, 49, 56, 56] |
q_pe | copy of pe | [1, 16, 56, 56] |
att | (q_att.unsqueeze(2)*(unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1), then softmax | [8, 49, 56, 56] |
out_att | v_att with the same unfold processing as k_att, then (att.unsqueeze(1) * out_att).sum(2) and view | [2, 64, 56, 56] |
f_all | q, k, v after view and cat, passed through fc | [2, 9, 16, 3136] |
f_conv | permute + reshape of f_all | [2, 144, 56, 56] |
out_conv | dep_conv: Conv2d(144, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False) | [2, 64, 56, 56] |
out | weighted sum of out_att and out_conv | [2, 64, 56, 56] |

Reference: [[CVPR 2022] Integrating self-attention and convolution: ACmix improves both performance and speed](https://zhuanlan.zhihu.com/p/490226994)

![Image](https://img-blog.csdnimg.cn/94afb9ec2cb940bd9b13ec8ced5379b3.png)
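Variants and alternatives of convolutional layers
The paper "Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions" proposes the Shift operation, which can essentially be viewed as a convolution with specially designed, fixed weights. If every kernel of a DepthwiseConv is replaced by a non-trainable kernel that is 1 at a single position and 0 everywhere else, the result is a Shift operation. Intuitively, Shift first translates each feature map and then extracts information with a 1x1 convolution. In an implementation, the translation can be done by operating on memory directly, which is why it costs zero FLOPs.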
https://zhuanlan.zhihu.com/p/530428399
https://zhuanlan.zhihu.com/p/458016349
How to understand inductive bias