计算机视觉 - Attention机制(附代码)

Attention机制

  • 1.Attention简介
  • 2.Attention原理
  • 3.Attention的不同类型
  • 4.CBAM实现(Pytorch)

1.Attention简介

Attention中文意思为注意力,这个机制放到计算机视觉里,类似于给我们看一张美女帅哥的图片,我们第一眼首先关注的地方是这个人的哪里呢

你们第一眼看的是哪里呢
计算机视觉 - Attention机制(附代码)_第1张图片

最早attention机制就应用到计算机视觉中,这里说的机制,其实就是神经网络中一个模块,类似于U-Net加上attention机制的变化。
计算机视觉 - Attention机制(附代码)_第2张图片
计算机视觉 - Attention机制(附代码)_第3张图片

看出什么变化了吗,其实就是在原始的网络结构增加一些结构模块。
随着NLP领域的发展,也开始应用了atteniton机制,除了这个还有循环神经网络(RNNs)和门控循环单元(GRUs)﹑长短期记忆(LSTMs)﹑序列对序列(Seq2Seq)﹑记忆网络(Memory Networks)等。这些都是Encoder-Decoder的不同框架。

但是attention是可以脱离Encoder-Decoder,被其他模型框架使用的。

2.Attention原理

在平时最常用的淘宝,得物的照片识别,其实算法都是使用attention这个机制的。

Attention 原理的3步分解
计算机视觉 - Attention机制(附代码)_第4张图片
第一步: query 和 key 进行相矩阵相乘

第二步:将矩阵相乘得到的结果根据不同权重进行归一化

第三步:将结果和 value 再进行一次矩阵相乘

这里步骤中提到的query、key、value其实就是我们的feature maps分别跟1x1的卷积核卷积得到的三个向量。如下图所示。
计算机视觉 - Attention机制(附代码)_第5张图片
整体步骤,可以这样理解:

  1. 第一步我们先生成一个包含像素特征的图像value
  2. 第二步,我们生成出我们需要找的特征图像query,比如说得物,我们需要找到图像中的鞋的细节(鞋底,鞋带。。。)。
  3. 第三步,我们给图像中所有的特征都做一个编号key。
  4. 我们的方式就是通过query去查找到图像中key,提取我们需要的key,并与value结合,利用权重,得到实际我们想要查找的图像关键区域。

说白了,在attention机制就是一种特征图的权重分布,把有用的特征权重加大,没有的特征权重加小,再用学出来的权重施加在原特征图之上最后进行加权求和。

3.Attention的不同类型

目前attention已经应用到计算机视觉,自然语言处理等多个领域,这些不同领域的应用,虽然attention的结构不变,但是其中的query、key、value的计算方式是不同的。计算区域也不同(一个卷积核乘积,不是所有的feature maps都做乘积)。
计算机视觉 - Attention机制(附代码)_第6张图片

前面attention原理,介绍的是attention的通用版本。这里我只提计算机视觉方面的attention,在计算机视觉中,主要有三种attention,分别为:

  • spatial attention:对于卷积神经网络,CNN每一层都会输出一个C x H x W的特征图,C就是通道,同时也代表卷积核的数量,亦为特征的数量,H 和W就是原始图片经过压缩后的图的高度和宽度,spatial attention就是对于所有的通道,在二维平面上,对H x W尺寸的特征图学习到一个权重,对每个像素都会学习到一个权重。你可以想象成一个像素是C维的一个向量,深度是C,在C个维度上,权重都是一样的,但是在平面上,权重不一样。
  • channel attention:就是对每个C(通道),在channel维度上,学习到不同的权重,平面维度上权重相同。所以基于通道域的注意力通常是对一个通道内的信息直接全局平均池化,而忽略每一个通道内的局部信息。SENet算法就是使用的channel attention。
  • spatial attention与channel attention融合:CBAM(Convolutional Block Attention Module)[5] 是其中的代表性网络,结构如下:
    计算机视觉 - Attention机制(附代码)_第7张图片
    其中Channel Attention Module模块:
    计算机视觉 - Attention机制(附代码)_第8张图片
    同时使用最大 pooling 和均值 pooling 算法,然后经过几个 MLP 层获得变换结果,最后分别应用于两个通道,使用 sigmoid 函数得到通道的 attention 结果。

其中Spatial Attention Module模块:
计算机视觉 - Attention机制(附代码)_第9张图片
首先将通道本身进行降维,分别获取最大池化和均值池化结果,然后拼接成一个特征图,再使用一个卷积层进行学习。

这两种机制,分别学习了通道的重要性和空间的重要性,还可以很容易地嵌入到任何已知的框架中。

4.CBAM实现(Pytorch)

CBAM模块详细:

其中Channel Attention模块:

其中Spatial Attention模块:
计算机视觉 - Attention机制(附代码)_第10张图片
代码如下:

import torch 
import torch.nn as nn
import torchvision

#ratio 为通道数
class ChannelAttention(nn.Moudel):
	def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x))
        print(avgout.shape)
        maxout = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out


class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        print('outchannels:{}'.format(out.shape))
        out = self.spatial_attention(out) * out
        return out

你可能感兴趣的:(Pytorch,图像处理,pytorch)