【模块】Non-local Neural

论文《Non-local Neural Networks》

作用

非局部神经网络通过非局部操作捕获长距离依赖,这对于深度神经网络来说至关重要。这些操作允许模型在空间、时间或时空中的任何位置间直接计算相互作用,从而捕获长距离的交互和依赖关系。这种方法对于视频分类、对象检测/分割以及姿态估计等任务表现出了显著的改进。

机制

非局部操作通过在输入特征图的所有位置上计算响应的加权和来实现,其中权重由位置之间的关系(如相似性)确定。这种操作可以直接插入许多计算机视觉架构中。在视频处理应用中,非局部块(基本单位)可以直接以前馈方式捕获时空依赖性。

独特优势

1、直接捕获长距离依赖

与重复应用局部操作(如卷积和递归操作)逐渐传递信号不同,非局部操作可以直接处理任意两个位置间的相互作用,无论它们的位置距离有多远。

2、计算效率高

尽管能够处理长距离依赖,但非局部模型在只有几层的情况下即可达到最佳效果,例如,在视频分类任务中,即使没有任何额外技巧,非局部模型也能与当前的竞争者相抗衡或超越。

3、保持输入大小的灵活性

非局部操作支持可变大小的输入,并能保持输出大小与输入相同,使其易于与其他操作(例如卷积)结合使用。

4、在多种任务上的通用性和有效性

无论是在动态视频还是静态图像识别任务上,加入非局部块的模型都显示出了对基线模型的明显改进,同时额外的计算成本很小。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd

# 非局部注意力模块的实现
class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]# 断言,确保维度为1,2,或3

        self.dimension = dimension# 保存维度信息
        self.sub_sample = sub_sample# 是否进行子采样

        self.in_channels = in_channels# 输入通道数
        self.inter_channels = inter_channels # 中间通道数
        # 如果没有指定中间通道数,则默认为输入通道数的一半,但至少为1
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        conv_nd = nn.Conv2d
        max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))# 最大池化层
        bn = nn.BatchNorm2d# 批归一化
        # g函数:降维
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        # 如果使用批归一化
        if bn_layer:
            # W函数:升维并使用批归一化
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0) # 初始化W函数权重为0
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )
        # 如果进行子采样,则在g和phi函数后添加最大池化层
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):


        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)# 在宽度维度上重复
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


if __name__ == '__main__':

    a = torch.ones(3, 32, 20, 20)  #生成随机数
    b = _NonLocalBlockND(32, 32)   #实例化
    c = b(a)
    print(c.size())

你可能感兴趣的:(扒网络模块,深度学习,pytorch,python)