计算机视觉之 GSoP 注意力模块

计算机视觉之 GSoP 注意力模块

一、简介

GSopBlock 是一个自定义的神经网络模块,主要用于实现 GSoP(Global Second-order Pooling)注意力机制。GSoP 注意力机制通过计算输入特征的协方差矩阵,捕捉全局二阶统计信息,从而增强模型的表达能力。

原论文:《Global Second-order Pooling Convolutional Networks (arxiv.org)》

二、语法和参数

语法
class GSopBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        ...
    def forward(self, x):
        ...
参数
  • in_channels:输入特征的通道数。
  • mid_channels:中间层的通道数,用于调整特征维度。

三、实例

3.1 初始化和前向传播
  • 代码
import torch
import torch.nn as nn

class GSopBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(GSopBlock, self).__init__()
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )
        self.row_wise_conv = nn.Sequential(
            nn.Conv2d(
                mid_channels, 4*mid_channels,
                kernel_size=(mid_channels, 1),
                groups=mid_channels, bias=False),
            nn.BatchNorm2d(4*mid_channels),
        )
        self.conv2d2 = nn.Sequential(
            nn.Conv2d(4*mid_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Step 1: 调整通道数
        x = self.conv2d1(x)
        batch_size, channels, height, width = x.size()

        # Step 2: 展平输入
        x_flat = x.view(batch_size, channels, -1)

        # Step 3: 计算协方差矩阵
        x_mean = x_flat.mean(dim=-1, keepdim=True)
        x_centered = x_flat - x_mean
        cov_matrix = torch.bmm(x_centered, x_centered.transpose(1, 2)) / (height * width)
        cov_matrix = cov_matrix.unsqueeze(-1)

        # Step 4: 行方向卷积
        cov_features = self.row_wise_conv(cov_matrix)

        # Step 5: 生成权重向量
        weight_vector = self.conv2d2(cov_features)

        # Step 6: 计算最终输出
        x_out = x * weight_vector

        return x_out
  • 输出
经过加权后的图像
3.2 应用在示例数据上
  • 代码
import torch

# 创建示例输入数据
input_tensor = torch.randn(1, 64, 32, 32)  # (batch_size, in_channels, height, width)

# 初始化 GSopBlock 模块
gsop_block = GSopBlock(in_channels=64, mid_channels=16)

# 前向传播
output_tensor = gsop_block(input_tensor)
print(output_tensor.shape)
  • 输出
torch.Size([1, 64, 32, 32])

四、注意事项

  1. GSopBlock 模块适用于捕捉输入特征之间的全局二阶统计信息,增强模型的表达能力。
  2. 在使用 GSopBlock 时,确保输入特征的通道数和中间层的通道数设置合理,以避免计算开销过大。
  3. 该模块主要用于图像数据处理,适用于各种计算机视觉任务,如图像分类、目标检测等。

你可能感兴趣的:(计算机视觉(CV),深度学习,机器学习,人工智能)