论文题目:
基于全局和局部残差图像预测的红外小型无人机目标检测(Infrared Small UAV Target Detection Based on Residual Image Prediction via Global and Local)
提出网络结构:DRUNet
论文地址:https://ieeexplore.ieee.org/document/9452107
论文复现:GitHub
热红外成像具有在白天和夜间条件下监测无人机 (UAV) 的能力。然而,红外无人机的远程探测往往受到目标小/暗淡、杂波大、复杂背景噪声的影响。传统的基于局部先验和非局部先验的方法通常具有较高的误报率和较低的检测精度。在这篇文章中,我们提出了一个模型,将小型无人机检测转换为预测残差图像(即背景、杂波和噪声)的问题。这种新颖的结构使我们能够直接学习从输入红外图像到残差图像的映射。构建的图像到图像网络将全局和局部的残差卷积块集成到 U-Net 中,可以很好地捕获局部和上下文结构信息,并融合不同尺度的特征进行图像重建。此外,利用亚像素卷积来放大图像并避免上采样过程中的图像失真。最后,通过从输入的红外图像中减去残差图像,得到小型无人机目标图像。对比实验表明,所提出的方法在检测具有重杂波和暗淡目标的真实红外图像方面优于最先进的方法。
我们的工作受到用于图像恢复的残差图像学习的启发[11]。 [11]中的 DnCNN 网络最初用于预测图像损坏(例如,噪声)而不是干净的图像,这通常是一项更容易的任务。 在这封信中,我们将红外小目标检测建模为 CNN 预测残差图像。 背景、杂波和噪声被视为输入红外图像与小型无人机目标图像之间的残差图像。 一旦估计残差图像,就通过从输入图像中减去残差图像得到小目标图像。
由于其结构复杂,直接使用 DnCNN 网络 [11] 预测残差图像时很难获得令人满意的性能。为了解决上述问题,我们提出了一种将全局和局部残差卷积块合并到 U-Net 网络中的模型。全局和局部的残差卷积块,以及编码和解码阶段之间的跳跃连接,防止梯度在反向传播过程中消失或爆炸,使网络训练更加稳定。特别是全局残差连接(GRC)可以增强浅层信息向深层的流动,减少特征图信息的丢失。此外,引入带有扩张卷积的局部残差块可以扩大感受野并捕获更多的上下文信息来重建残差图像的多个结构分量,而无需增加网络深度。此外,采用高效子像素卷积(ESPC)[12] 来放大图像并避免解码阶段的图像失真。总之,这篇论文的主要贡献如下:
- 提出了一种新的多尺度U-Net架构,它可以预测输入图像与目标图像之间的残差图像,用于红外小型无人机目标检测。
- 提出了一个由两个连续的局部残块组成的全局残块,它不仅可以捕获局部特征和上下文特征,还可以通过GRC更好地融合前一层和当前层的特征。
- 提出的方法在具有重杂波、复杂建筑和暗淡目标的真实红外无人机图像的定性和定量评估方面明显优于最先进的方法。
import torch
import torch.nn as nn
class InputCov(nn.Module):
"""
处理原始图像
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.input_cov = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
nn.ReLU(inplace=False),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=2, stride=1, dilation=2),
nn.ReLU(inplace=False),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=2, stride=1, dilation=2)
)
def forward(self, initial_data):
input_cov = self.input_cov(initial_data) # (batch,C,W,H)
input_cov _ = torch.add(input_cov , initial_data) #特征融合
return input_cov _
class StdCovLocalResBlock(nn.Module):
"""
使用标准卷积的局部残差模块
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.std_lrb = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
nn.ReLU(inplace=False),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1)
)
def forward(self, input_map):
std_lrb = self.std_lrb(input_map) # (btch,C,W,H)
add_map = torch.sum(input_map, dim=1) # (batch,W,H)
add_map_1 = torch.div(add_map, input_map.shape[1]) #加和求均值
add_map_2 = add_map_1.unsqueeze(1) # (batch,1,W,H)
output_map= torch.add(std_lrb , add_map_2) #特征融合
return output_map
class DilCovLocalResBlock(nn.Module):
"""
使用空洞卷积的局部残差模块
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.dil_lrb = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=2, stride=1, dilation=2),
nn.ReLU(inplace=False),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=2, stride=1, dilation=2)
)
def forward(self, input_map):
dil_lrb= self.dil_lrb(input_map) # (abtch,C,W,H)
add_map = torch.sum(input_map, dim=1) # (batch_size,W,H)
add_map_1 = torch.div(add_map, input_map.shape[1])
add_map_2 = add_map_1.unsqueeze(1) # (batch,1,W,H)
output_map= torch.add(dil_lrb, add_map_2)
return output_map
class LeftGlobalResBlock(nn.Module):
"""
网络左半部分的全局残差模块
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.std_cov = StdCovLocalResBlock(in_channels, out_channels)
self.dil_cov = DilCovLocalResBlock(out_channels, out_channels)
def forward(self, down_map):
std_cov = self.std_cov(down_map)
dil_cov = self.dil_cov(std_cov)
add_map = torch.sum(down_map, dim=1)
add_map_1 = torch.div(add_map, down_map.shape[1])
add_map_2 = add_map_1.unsqueeze(1) # (batch,1,W,H)
output_map= torch.add(dil_cov, add_map_2)
return output_map
class RightGlobalResBlock(nn.Module):
"""
网络右半部分的全局残差模块
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.std_cov_1 = StdCovLocalResBlock(in_channels, out_channels)
self.std_cov_2 = StdCovLocalResBlock(out_channels, out_channels)
def forward(self, up_map):
std_cov_1 = self.std_cov_1(up_map)
std_cov_2 = self.std_cov_2(std_cov_1)
add_map = torch.sum(up_map, dim=1)
add_map_1 = torch.div(add_map, up_map.shape[1])
add_map_2 = add_map_1.unsqueeze(1) # (batch,1,W,H)
output_map= torch.add(std_cov_2, add_map_2)
return output_map
class Down(nn.Module):
"""
下采样模块
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.down = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding=0, stride=2)
def forward(self, input_map):
return self.down(input_map)
class Up(nn.Module):
"""
上采样模块
"""
def __init__(self):
super().__init__()
self.up = nn.PixelShuffle(2)
def forward(self, input_map, skip_map):
up = self.up(input_map)
up_map = torch.sum(up, dim=1)
up_map_1 = torch.div(up_map, input_map.shape[1])
up_map_2 = up_map_1.unsqueeze(1) # (batch,1,W,H)
output_map= torch.add(skip_map, up_map_2)
return output_map
class OutputCov(nn.Module):
"""
输出
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.cov = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1)
def forward(self, input_map):
return self.cov(input_map)
from .DRUNet_parts import *
class DRUNet(nn.Module):
def __init__(self):
super(DRUNet, self).__init__()
self.input = InputCov(1, 64)
self.down1 = Down(64, 128)
self.left1 = LeftGlobalResBlock(128, 128)
self.down2 = Down(128, 256)
self.left2 = LeftGlobalResBlock(256, 256)
self.down3 = Down(256, 512)
self.left3 = LeftGlobalResBlock(512, 512)
self.up1 = Up()
self.right1 = RightGlobalResBlock(256, 256)
self.up2 = Up()
self.right2 = RightGlobalResBlock(128, 128)
self.up3 = Up()
self.right3 = RightGlobalResBlock(64, 64)
self.output = OutputCov(64, 1)
def forward(self, init_img):
input_map = self.input(init_img)
down1_map = self.down1(input_map)
left1_map = self.left1(down1_map)
down2_map = self.down2(left1_map)
left2_map = self.left2(down2_map)
down3_map = self.down3(left2_map)
left3_map = self.left3(down3_map)
up1_map = self.up1(left3_map, left2_map)
right1_map = self.right1(up1_map)
up2_map = self.up2(right1_map, left1_map)
right2_map = self.right2(up2_map)
up3_map = self.up3(right2_map, input_map)
right3_map = self.right3(up3_map)
output = self.output(right3_map)
return output
if __name__ == '__main__':
net = DRUNet()
input_ = torch.Tensor(3, 1, 512, 512)
out = net(input_)
print(out)