CNN与Transformer的结合-BoTNet

简单学习BoTNet

1.简单介绍
如今transformer热度较高,越来越多的人都来研究transformer,并将其用于CV领域,下图(出自BoTNet论文)就介绍了当前self-Attention在CV领域的应用,而本篇文章的主角BoTNet就是利用CNN+transformer的方式提出一种Bottleneck Transformer来代替ResNet Bottleneck。
CNN与Transformer的结合-BoTNet_第1张图片
2.大体结构
如图所示总体结构比较简单,就是在ResNet-50的结构基础上将C5的三个Bottleneck替换为带MHSA的Bottleneck。

CNN与Transformer的结合-BoTNet_第2张图片
这里C5的Bottleneck也不一定3个全用MHSA,为此文章也做了对比实验
CNN与Transformer的结合-BoTNet_第3张图片
实验结果
CNN与Transformer的结合-BoTNet_第4张图片

带MHSA的Bottleneck如下图所示
CNN与Transformer的结合-BoTNet_第5张图片
接下来我们看看MHSA是怎么具体实现的
CNN与Transformer的结合-BoTNet_第6张图片
结构大体与transformer一致,不过值得注意的是这个content -position与transformer的位置嵌入还是有区别的,有兴趣的同学可以去单独了解下content-position。
MHSA的代码实现如下

class MHSA(nn.Module):
    def __init__(self, n_dims, width, height):
        super(MHSA, self).__init__()

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        # nn.Parameter 含义是将一个固定不可训练的tensor转换成可以训练的类型parameter
        self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, C, -1)
        k = self.key(x).view(n_batch, C, -1)
        v = self.value(x).view(n_batch, C, -1)
        # 对存储在两个批bach1和batch内的矩阵进行批矩阵乘操作。batch1和2都包含相同数量矩阵的三维张量
        content_content = torch.bmm(q.permute(0, 2, 1), k)

        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(n_batch, C, width, height)

        return out

这个代码快也还是比较容易理解的
3.实验结果
单独使用的BoTNet在目标检测中取得了不错的结果
CNN与Transformer的结合-BoTNet_第7张图片
但在图像分类中效果取并没有很强,为此文章又提出了BoTNet-S1结构哦,如下图
CNN与Transformer的结合-BoTNet_第8张图片
CNN与Transformer的结合-BoTNet_第9张图片

这里将stride由2改为了1,改为1后其实就是上图的平均池化还有卷积的stride改为了1,然而却取得了非常不错得效果
CNN与Transformer的结合-BoTNet_第10张图片

另外文章还有做了一些其他的对比实验,就放在下面了
其中BoTNet的卷积块添加了SE(注意力机制)模块
CNN与Transformer的结合-BoTNet_第11张图片
CNN与Transformer的结合-BoTNet_第12张图片
(上述有发现任何错误和问题可随时联系我)

你可能感兴趣的:(transformer,深度学习,计算机视觉,人工智能)