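"""DBNet text detection model.

1. FPN neck (backbone: ResNet-50)
2. DB head: two binarization branches producing prob_map and threshold_map
3. Segout (segmentation output head)

Approximate binary map:
    torch.reciprocal(1 + torch.exp(-k * (prob_map - threshold_map)))
"""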
import torch
import torch.nn as nn
#class:Res,Resnet50,FPN,SegoutDetector->DBnet
class DBnet(nn.Module):
    """DBNet detector: ResNet-50 backbone -> FPN neck -> segmentation head.

    Args:
        serial: forwarded to ``SegoutDetector``; selects whether the threshold
            branch also sees the probability map.

    In training mode ``forward`` returns ``(prob_map, threshold_map, ab_map)``;
    in eval mode it returns only ``prob_map``.
    """

    def __init__(self, serial=False):
        super().__init__()
        self.backbone = Resnet50()
        self.head = FPN()
        self.seg_out = SegoutDetector(serial=serial)

    def forward(self, x):
        # backbone yields (c2, c3, c4, c5); the FPN turns them into (p2, p3, p4, p5)
        features = self.backbone(x)
        pyramid = self.head(features)
        return self.seg_out(pyramid)
class Res(nn.Module):
    """ResNet bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    The block emits ``4 * inner_channel`` channels. A ``stride`` other than 1
    is applied by the 3x3 convolution, which shrinks the feature map.

    Args:
        in_channel: number of input channels.
        inner_channel: width of the bottleneck (the 1x1 and 3x3 convs).
        stride: stride of the 3x3 conv and of the shortcut projection.
    """

    def __init__(self, in_channel, inner_channel, stride=1):
        super().__init__()
        self.expansion = 4
        out_channel = self.expansion * inner_channel
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channel, inner_channel, 1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, inner_channel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, out_channel, 1, bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.relu = nn.ReLU(inplace=True)
        # The shortcut needs a strided 1x1 projection whenever the residual
        # branch changes spatial size or channel count; otherwise it is identity.
        needs_projection = stride != 1 or in_channel != out_channel
        self.dsample = (
            nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1, stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        residual = self.bottleneck(x)
        residual += x if self.dsample is None else self.dsample(x)
        return self.relu(residual)
class Resnet50(nn.Module):
    """ResNet-50 backbone returning the C2-C5 feature maps.

    Stages C2..C5 stack [3, 4, 6, 3] bottleneck blocks of width
    64/128/256/512 (output channels 256/512/1024/2048). The first block of
    C3, C4 and C5 uses stride 2.
    """

    def __init__(self):
        super().__init__()
        self.make_c1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # [3, 4, 6, 3]
        self.make_c2 = self._stage(64, 64, blocks=3, stride=1)
        self.make_c3 = self._stage(256, 128, blocks=4, stride=2)
        self.make_c4 = self._stage(512, 256, blocks=6, stride=2)
        self.make_c5 = self._stage(1024, 512, blocks=3, stride=2)

    @staticmethod
    def _stage(in_channel, inner_channel, blocks, stride):
        """Build one stage: an entry block with ``stride``, then stride-1 blocks."""
        out_channel = 4 * inner_channel
        layers = [Res(in_channel=in_channel, inner_channel=inner_channel, stride=stride)]
        layers.extend(
            Res(in_channel=out_channel, inner_channel=inner_channel, stride=1)
            for _ in range(blocks - 1)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        c2 = self.make_c2(self.make_c1(x))
        c3 = self.make_c3(c2)
        c4 = self.make_c4(c3)
        c5 = self.make_c5(c4)
        return c2, c3, c4, c5
class FPN(nn.Module):
    """Feature pyramid neck mapping C2-C5 features to 256-channel P2-P5 maps.

    ``forward`` takes the ``(c2, c3, c4, c5)`` tuple (256/512/1024/2048
    channels) and returns ``(p2, p3, p4, p5)``, each with 256 channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1 reduction of the top level C5 (2048 channels) to 256 channels.
        self.make_p5 = nn.Conv2d(512 * 4, 256, 1, 1, 0)
        # Lateral 1x1 convs so every level has 256 channels before summing.
        self.lat_c4 = nn.Conv2d(1024, 256, 1, 1, 0)
        self.lat_c3 = nn.Conv2d(512, 256, 1, 1, 0)
        self.lat_c2 = nn.Conv2d(256, 256, 1, 1, 0)
        # 3x3 convs that smooth each merged level.
        self.smooth1 = nn.Conv2d(256, 256, 3, 1, 1)
        self.smooth2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.smooth3 = nn.Conv2d(256, 256, 3, 1, 1)

    def _upsample_add(self, x, y):
        """Nearest-upsample ``x`` to ``y``'s spatial size, add it to ``y`` in place, return ``y``."""
        _, _, height, width = y.shape
        y += nn.functional.interpolate(x, size=(height, width))
        return y

    def forward(self, x):
        c2, c3, c4, c5 = x
        # Top-down pathway: each level is the smoothed sum of its lateral
        # projection and the upsampled level above it.
        levels = [self.make_p5(c5)]
        for lateral, smooth, feat in (
            (self.lat_c4, self.smooth1, c4),
            (self.lat_c3, self.smooth2, c3),
            (self.lat_c2, self.smooth3, c2),
        ):
            levels.append(smooth(self._upsample_add(levels[-1], lateral(feat))))
        p5, p4, p3, p2 = levels
        return p2, p3, p4, p5
class SegoutDetector(nn.Module):
    """DBNet segmentation head: fuses P2-P5 and predicts probability/threshold maps.

    ``forward`` takes the ``(p2, p3, p4, p5)`` tuple (256 channels each),
    fuses it at P2 resolution into 256 channels, and each branch then
    upsamples 4x through two stride-2 transposed convolutions.

    In training mode ``forward`` returns ``(prob_map, threshold_map, ab_map)``;
    in eval mode it returns only ``prob_map``.

    Args:
        serial: if True, the threshold branch receives the fused features
            concatenated with the probability map (downsampled to the fused
            resolution); if False it is computed from the fused features alone.

    Note:
        The per-level fusion convs (``merge_convs``) and, for ``serial=False``,
        the ``threshold`` branch are registered submodules. Checkpoints saved
        before these were registered do not contain those keys.
    """

    def __init__(self, serial=False):
        super(SegoutDetector, self).__init__()
        # True: threshold map from fused features + probability map;
        # False: threshold map from the fused features only.
        self.serial = serial
        # probability map
        self.binarize = self._make_branch(256)
        # threshold map: always its own branch (previously the non-serial path
        # reused `binarize`, making threshold_map identical to prob_map).
        # Serial mode adds one input channel for the concatenated prob map.
        self.threshold = self._make_branch(257 if serial else 256)
        # One 3x3 conv (256 -> 64) per pyramid level p2..p5. These are
        # registered here so they are trained, saved and moved with the model
        # instead of being re-created with random weights on every forward.
        self.merge_convs = nn.ModuleList([nn.Conv2d(256, 64, 3, 1, 1) for _ in range(4)])

    @staticmethod
    def _make_branch(in_channels):
        """Build a conv -> 2x deconv -> 2x deconv -> sigmoid head with 1 output channel."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, 2, 2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels=1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        p2, p3, p4, p5 = x
        fuse = self.merge(p2, p3, p4, p5)
        # probability map
        prob_map = self.binarize(fuse)
        # only the probability map is needed at inference time
        if not self.training:
            return prob_map
        # threshold map
        if self.serial:
            # downsample the probability map to the fused resolution and concatenate
            thresh_in = torch.cat((fuse, nn.functional.interpolate(prob_map, fuse.shape[2:])), 1)
        else:
            thresh_in = fuse
        threshold_map = self.threshold(thresh_in)
        # approximate binary map
        ab_map = self.ab_map(prob_map, threshold_map)
        return prob_map, threshold_map, ab_map

    def merge(self, p2, p3, p4, p5):
        """Reduce each level to 64 channels, resize to p2's size (nearest), concatenate.

        Returns a tensor with 4 * 64 = 256 channels at p2's spatial resolution.
        """
        size = p2.shape[2:]
        reduced = [
            nn.functional.interpolate(conv(feat), size=size)
            for conv, feat in zip(self.merge_convs, (p2, p3, p4, p5))
        ]
        return torch.cat(reduced, dim=1)

    # approximate binary map
    def ab_map(self, x, y, k=50):
        """Approximate binary map ``1 / (1 + exp(-k * (x - y)))``.

        Computed as ``sigmoid(k * (x - y))``, which is the same function but
        does not overflow ``exp`` for large ``k * (y - x)``; the explicit form
        yields NaN gradients in that case.
        """
        return torch.sigmoid(k * (x - y))
if __name__ == "__main__":
    # Smoke test: build a DBnet, report its state_dict size, and run one
    # training-mode forward pass (which yields all three maps).
    model = DBnet()
    print(len(model.state_dict()))
    batch = torch.randn(2, 3, 512, 512)
    prob, thresh, approx = model(batch)
    print(prob.shape, thresh.shape, approx.shape)
# if __name__=="__main__":
# res=Res(64,64,2)
# print(res)
# if __name__=="__main__":
# r=Resnet50()
# x=torch.randn(1,3,512,512)
# c2, c3, c4, c5=r(x)
# print(c2.shape,c3.shape,c4.shape,c5.shape)
# if __name__=="__main__":
# r=Resnet50()
# f=FPN()
# x=torch.randn(1,3,512,512)
# c2,c3,c4,c5=r(x)
# p2,p3,p4,p5=f((c2,c3,c4,c5))
# print(p2.shape)
# print(p3.shape)
# print(p4.shape)
# print(p5.shape)