最近在学习目标检测领域的 YOLOv5 算法,发现 PSA(极化自注意力机制)对该算法的改进可能有用,于是在网上几经搜寻,无果,遂自己动手写了一个,现分享给大家。
论文链接: Polarized Self-Attention: Towards High-quality Pixel-wise Regression
代码地址: https://github.com/DeLightCMU/PSA
|
|
|
笔者在网上没有找到 PyTorch 框架下的 PSA 模块源码,于是根据论文中的流程自己动手写了一个。
论文中的流程图:
class PSA_Channel(nn.Module):
    """Channel branch of the Polarized Self-Attention (PSA) block.

    Computes one gate value per channel (identical for every spatial
    position) and multiplies the input by it, so the output has the same
    shape as the input.

    Args:
        c1: Number of channels of the input feature map. The intermediate
            projection uses ``c1 // 2`` channels.
    """

    def __init__(self, c1: int) -> None:
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1)  # 1x1 conv: c1 -> c1 // 2 channels
        self.cv2 = nn.Conv2d(c1, 1, 1)  # 1x1 conv: c1 -> single-channel map
        self.cv3 = nn.Conv2d(c_, c1, 1)  # 1x1 conv: c1 // 2 -> c1 channels
        self.reshape1 = nn.Flatten(start_dim=-2, end_dim=-1)  # merge H and W
        self.reshape2 = nn.Flatten()  # merge everything except the batch dim
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(1)
        self.layernorm = nn.LayerNorm([c1, 1, 1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel attention.

        Args:
            x: Feature map of shape (batch, c1, height, width).

        Returns:
            Tensor with the same shape as ``x``: the input scaled by a
            per-channel gate in (0, 1).
        """
        x1 = self.reshape1(self.cv1(x))  # shape(batch, c1 // 2, height*width)
        # Softmax runs over the flattened spatial positions of the
        # single-channel map, giving one weight per position.
        x2 = self.softmax(self.reshape2(self.cv2(x)))  # shape(batch, height*width)
        # Batched matrix product over the last two dimensions:
        # (c1 // 2, H*W) @ (H*W, 1) -> (c1 // 2, 1), i.e. a softmax-weighted
        # sum of x1 over positions. The trailing unsqueeze restores a
        # 4-D (batch, c1 // 2, 1, 1) layout for the following convolution.
        y = torch.matmul(x1, x2.unsqueeze(-1)).unsqueeze(-1)
        # (batch, c1, 1, 1) gate broadcasts over height and width.
        return self.sigmoid(self.layernorm(self.cv3(y))) * x
class PSA_Spatial(nn.Module):
    """Spatial branch of the Polarized Self-Attention (PSA) block.

    Computes one gate value per spatial position (identical for every
    channel) and multiplies the input by it, so the output has the same
    shape as the input.

    Args:
        c1: Number of channels of the input feature map. Both internal
            projections use ``c1 // 2`` channels.
    """

    def __init__(self, c1: int) -> None:
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1)  # 1x1 conv: c1 -> c1 // 2 channels
        self.cv2 = nn.Conv2d(c1, c_, 1)  # 1x1 conv: c1 -> c1 // 2 channels
        self.reshape1 = nn.Flatten(start_dim=-2, end_dim=-1)  # merge H and W
        self.globalPooling = nn.AdaptiveAvgPool2d(1)  # average over H and W
        self.softmax = nn.Softmax(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply spatial attention.

        Args:
            x: Feature map of shape (batch, c1, height, width).

        Returns:
            Tensor with the same shape as ``x``: the input scaled by a
            per-position gate in (0, 1).
        """
        x1 = self.reshape1(self.cv1(x))  # shape(batch, c1 // 2, height*width)
        # Global average pooling collapses each channel to one value; the
        # softmax then runs across the c1 // 2 channels.
        x2 = self.softmax(self.globalPooling(self.cv2(x)).squeeze(-1))  # shape(batch, c1 // 2, 1)
        # (1, c1 // 2) @ (c1 // 2, H*W): softmax-weighted sum of x1 over channels.
        y = torch.bmm(x2.permute(0, 2, 1), x1)  # shape(batch, 1, height*width)
        # Restore the spatial layout; the single-channel gate broadcasts
        # over all c1 input channels.
        return self.sigmoid(y.view(x.shape[0], 1, x.shape[2], x.shape[3])) * x
class PSA(nn.Module):
    """Polarized Self-Attention block combining the channel and spatial branches.

    Args:
        in_channel: Number of channels of the input feature map; passed to
            both ``PSA_Channel`` and ``PSA_Spatial``.
        parallel: If true (default), both branches are applied to the input
            independently and their outputs are summed. Otherwise the
            channel branch runs first and its output is fed to the spatial
            branch.
    """

    def __init__(self, in_channel: int, parallel: bool = True) -> None:
        super().__init__()
        self.parallel = parallel
        self.channel = PSA_Channel(in_channel)
        self.spatial = PSA_Spatial(in_channel)

    def forward(self, x):
        """Apply PSA to ``x`` and return the result of the selected layout."""
        if self.parallel:
            # Parallel layout: both branches see the same input.
            return self.channel(x) + self.spatial(x)
        # Sequential layout: channel attention first, then spatial attention.
        return self.spatial(self.channel(x))
这是我在学习 PyTorch 与 YOLOv5 算法过程中写的一个模块,此外还有更多模块和示例放在码云上。
仓库链接:各函数、模块例子
还望各位大佬不吝赐教