混合深度卷积和自注意力
论文中提到了的 ViT 的主要限制之一是其令人印象深刻的数据需求。虽然 ViT 在庞大的 JFT300M 数据集上显示出令人兴奋的结果,但它在数据量少的情况下性能仍然不如的经典 CNN。这表明 Transformers 可能缺少 CNN 拥有的泛化能力,因此需要大量数据来弥补。但是与 CNN 相比,注意力模型具有更高的模型容量。
CoAtNet 的目标是将 CNN 和 Transformer 的优点融合到一个单一的架构中,但是混合 CNN 和 Transformer 的正确方法是什么?
第一个想法是利用已经讨论过的 MBConv 块,它采用具有倒置残差的深度卷积,这种扩展压缩方案与 Transformer 的 FFN 模块相同。除了这种相似性之外,depthwise convolution 和 self-attention 都可以表示为一个预定义的感受野中每个维度的加权值之和。其中深度卷积可以表示为:
其中 xᵢ 和 yᵢ 分别是位置 i 的输入和输出, wᵢ ₋ ⱼ 是位置 (i - j) 的权重矩阵, L (i) 分别是 i. 通道的局部邻域。
相比之下,self-attention 允许感受野不是局部邻域,并基于成对相似性计算权重,然后是 softmax 函数:
其中 G 表示全局空间,xᵢ, xⱼ 是两对(例如图像的两个patch)。为了便于理解一个简化的版本(省略了多头 Q、K 和 V 投影),将每个patch与同一图像中的每个其他patch进行比较,以产生一个自注意力矩阵。
让我们尝试分析这两个公式的优缺点:
Input-Adaptive Weighting:矩阵 wᵢ an 是一个与输入无关的静态值,而注意力权重 Aᵢⱼ 取决于输入的表示。这使得 self-attention 更容易捕获输入中不同元素之间的关系,但代价是在数据有限时存在过度拟合的风险。
Translation Equivariance:卷积权重 wᵢ ⱼ ⱼ 关心的是 i 和 j 之间的相对偏移,而不是 i 和 j 的具体值。这种平移不变性可以在有限大小的数据集下提高泛化能力。
Global Receptive Field:与 CNN 的局部感受野相比,self-attention 中使用的更大感受野提供了更多的上下文信息。
综上所述,最优架构应该是自注意力的输入+自适应加权和全局感受野特性+ CNN 的平移不变性。所以作者提出的想法是在softmax初始化之后或之前将全局静态卷积核与自适应注意力矩阵相加:
有了上面的理论基础,下一步就是弄清楚如何堆叠卷积和注意力块。作者决定只有在特征图小到可以处理之后才使用卷积来执行下采样和全局相对注意力操作。并且执行下采样方式也有两种 :
像在 ViT 模型中一样将图像划分为块,并堆叠相关的自注意力块。该模型被用作与原始 ViT 的比较。
使用渐进池化的多阶段操作。这种方法分为5个阶段,但是前两个阶段,即经典的卷积层和用于降低维度的MBConv块。为了简单起见这里将其合并为一个阶段命名为S0。后面三个阶段可以是卷积或Transformer块,产生 4 种组合:S0-CCC、S0-CCT、S0-CTT 和 S0-TTT
这样产生的 5 个模型在泛化方面(训练损失和评估准确度之间的差距)和使用 1.3M 图像、超过 3B 图像的模型容量(拟合大型训练数据集的能力)进行了比较。
泛化能力:S0-CCC ≈ S0-CCT ≥ S0-CTT> S0-TTT ≫ ViT
模型容量:S0-CTT≈S0-TTT>ViT>S0-CCT>S0-CCC
对于泛化来说:卷积层越多,差距越小。
对于模型容量:简单地添加更多的 Transformer 块并不意味着更好的泛化。下图所示的 S0-CTT 被选为这两种功能之间的最佳折衷方案。
from torch import nn, sqrt
import torch
import sys
from math import sqrt
sys.path.append('.')
from model.conv.MBConv import MBConvBlock
from model.attention.SelfAttention import ScaledDotProductAttention
class CoAtNet(nn.Module):
def __init__(self, in_ch, image_size, out_chs=[64, 96, 192, 384, 768]):
super().__init__()
self.out_chs = out_chs
# 最大池化下采样
self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.maxpool1d = nn.MaxPool1d(kernel_size=2, stride=2)
# 卷积提取特征
self.s0 = nn.Sequential(
nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1)
)
# 增加维度
self.mlp0 = nn.Sequential(
nn.Conv2d(in_ch, out_chs[0], kernel_size=1),
nn.ReLU(),
nn.Conv2d(out_chs[0], out_chs[0], kernel_size=1)
)
# 卷积模块
# 倒残差结构
self.s1 = MBConvBlock(ksize=3, input_filters=out_chs[0], output_filters=out_chs[0], image_size=image_size // 2)
self.mlp1 = nn.Sequential(
nn.Conv2d(out_chs[0], out_chs[1], kernel_size=1),
nn.ReLU(),
nn.Conv2d(out_chs[1], out_chs[1], kernel_size=1)
)
self.s2 = MBConvBlock(ksize=3, input_filters=out_chs[1], output_filters=out_chs[1], image_size=image_size // 4)
self.mlp2 = nn.Sequential(
nn.Conv2d(out_chs[1], out_chs[2], kernel_size=1),
nn.ReLU(),
nn.Conv2d(out_chs[2], out_chs[2], kernel_size=1)
)
# 自注意力模块
# 四个输入分别为d_model, d_k, d_v, h
# :param d_model: Output dimensionality of the model
# :param d_k: Dimensionality of queries and keys
# :param d_v: Dimensionality of values
# :param h: Number of heads
self.s3 = ScaledDotProductAttention(out_chs[2], out_chs[2] // 8, out_chs[2] // 8, 8)
self.mlp3 = nn.Sequential(
nn.Linear(out_chs[2], out_chs[3]),
nn.ReLU(),
nn.Linear(out_chs[3], out_chs[3])
)
self.s4 = ScaledDotProductAttention(out_chs[3], out_chs[3] // 8, out_chs[3] // 8, 8)
self.mlp4 = nn.Sequential(
nn.Linear(out_chs[3], out_chs[4]),
nn.ReLU(),
nn.Linear(out_chs[4], out_chs[4])
)
def forward(self, x):
B, C, H, W = x.shape
# stage0
y = self.mlp0(self.s0(x))
y = self.maxpool2d(y)
# 倒残差模块
# stage1
y = self.mlp1(self.s1(y))
y = self.maxpool2d(y)
# stage2
y = self.mlp2(self.s2(y))
y = self.maxpool2d(y) # [1,192,28,28]
# stage3
y = y.reshape(B, self.out_chs[2], -1).permute(0, 2, 1) # B,N,C [1,784,192]
y = self.mlp3(self.s3(y, y, y)) # [1,784,384]
y = self.maxpool1d(y.permute(0, 2, 1)).permute(0, 2, 1) # [1,392,384]
# stage4
y = self.mlp4(self.s4(y, y, y))
y = self.maxpool1d(y.permute(0, 2, 1))
N = y.shape[-1]
y = y.reshape(B, self.out_chs[4], int(sqrt(N)), int(sqrt(N)))
return y
if __name__ == '__main__':
x = torch.randn(1, 3, 224, 224)
coatnet = CoAtNet(3, 224)
y = coatnet(x)
print(y.shape)