【pytorch】PSA极化自注意力机制

PSA极化自注意力机制

  • PSA极化自注意力机制
    • 使用效果
    • 极化过程示意图
    • 源代码

最近在学习目标检测领域的yolov5算法,发现PSA(极化自注意力机制)对于该算法的改进可能有用,于是在网上几经搜寻,无果,遂自己动手写了一个,现分享给大家

PSA极化自注意力机制

论文链接: Polarized Self-Attention: Towards High-quality Pixel-wise Regression
代码地址: https://github.com/DeLightCMU/PSA

使用效果

【pytorch】PSA极化自注意力机制_第1张图片图1 原图
【pytorch】PSA极化自注意力机制_第2张图片 图2 平行极化
【pytorch】PSA极化自注意力机制_第3张图片图3 顺序极化

极化过程示意图

作者在网上没有找到pytorch框架下的PSA模块源码,于是根据论文中的流程自己动手写了一个。
论文中的流程图:
【pytorch】PSA极化自注意力机制_第4张图片

源代码

class PSA_Channel(nn.Module):
    def __init__(self, c1) -> None:
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1)
        self.cv2 = nn.Conv2d(c1, 1, 1)
        self.cv3 = nn.Conv2d(c_, c1, 1)
        self.reshape1 = nn.Flatten(start_dim=-2, end_dim=-1)
        self.reshape2 = nn.Flatten()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(1)
        self.layernorm = nn.LayerNorm([c1, 1, 1])

    def forward(self, x): # shape(batch, channel, height, width)
        x1 = self.reshape1(self.cv1(x)) # shape(batch, channel/2, height*width)
        x2 = self.softmax(self.reshape2(self.cv2(x))) # shape(batch, height*width)
        y = torch.matmul(x1, x2.unsqueeze(-1)).unsqueeze(-1) # 高维度下的矩阵乘法(最后两个维度相乘)
        return self.sigmoid(self.layernorm(self.cv3(y))) * x

class PSA_Spatial(nn.Module):
    def __init__(self, c1) -> None:
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1)
        self.reshape1 = nn.Flatten(start_dim=-2, end_dim=-1)
        self.globalPooling = nn.AdaptiveAvgPool2d(1)
        self.softmax = nn.Softmax(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x): # shape(batch, channel, height, width)
        x1 = self.reshape1(self.cv1(x)) # shape(batch, channel/2, height*width)
        x2 = self.softmax(self.globalPooling(self.cv2(x)).squeeze(-1)) # shape(batch, channel/2, 1)
        y = torch.bmm(x2.permute(0,2,1), x1) # shape(batch, 1, height*width)
        return self.sigmoid(y.view(x.shape[0], 1, x.shape[2], x.shape[3])) * x

class PSA(nn.Module):
    def __init__(self, in_channel, parallel=True) -> None:
        super().__init__()
        self.parallel = parallel
        self.channel = PSA_Channel(in_channel)
        self.spatial = PSA_Spatial(in_channel)

    def forward(self, x):
        if(self.parallel):
            return self.channel(x) + self.spatial(x)
        return self.spatial(self.channel(x))

这是我在学习pytorch与yolov5算法过程中的写的一个模块,此外还有更多在码云上
仓库链接:各函数、模块例子

还望各位大佬不吝赐教

你可能感兴趣的:(pytorch,深度学习,python,计算机视觉,人工智能)