DBnet实现

DBnet的具体实现

1. FPN(主干网络为 ResNet-50)
DBnet实现_第1张图片
2. DB 头(可微分二值化,Differentiable Binarization):由融合特征经两个分支分别预测概率图 prob_map 和阈值图 threshold_map
DBnet实现_第2张图片

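3. Segout(近似二值图 approximate binary map,k 为放大系数,代码中取 k=50):
$\hat{B}_{i,j} = \dfrac{1}{1 + e^{-k\,(P_{i,j} - T_{i,j})}}$,其中 $P$ 为概率图,$T$ 为阈值图,对应代码:
torch.reciprocal(1 + torch.exp(-k * (prob_map - threshold_map)))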

pytorch实现

import torch
import torch.nn as nn

#class:Res,Resnet50,FPN,SegoutDetector->DBnet

class DBnet(nn.Module):
    """DBNet text detector: ResNet-50 backbone -> FPN neck -> DB head.

    Args:
        serial: forwarded to ``SegoutDetector``; selects whether the
            threshold branch also sees the probability map.
    """

    def __init__(self, serial=False):
        super().__init__()
        self.backbone = Resnet50()
        self.head = FPN()
        self.seg_out = SegoutDetector(serial=serial)

    def forward(self, x):
        """Run the full detector on a batch of 3-channel images.

        Returns:
            In training mode, ``(prob_map, threshold_map, ab_map)``;
            in eval mode, only ``prob_map``.
        """
        stage_feats = self.backbone(x)
        pyramid = self.head(stage_feats)
        return self.seg_out(pyramid)

class Res(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs) with a residual shortcut.

    The output has ``4 * inner_channel`` channels. ``stride=2`` halves the
    spatial size; the stride is applied in the 3x3 conv and in the projection
    shortcut.

    Args:
        in_channel: number of channels of the input tensor.
        inner_channel: width of the bottleneck 1x1/3x3 convolutions.
        stride: stride of the 3x3 conv (and of the projection shortcut).
    """

    def __init__(self, in_channel, inner_channel, stride=1):
        super().__init__()
        self.expansion = 4
        out_channel = self.expansion * inner_channel
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channel, inner_channel, 1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, inner_channel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, out_channel, 1, bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.relu = nn.ReLU(inplace=True)
        # The identity path must match the bottleneck output in both channel
        # count and spatial size; otherwise project it with a strided 1x1 conv.
        needs_projection = stride != 1 or in_channel != out_channel
        self.dsample = (
            nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1, stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        out = self.bottleneck(x)
        shortcut = x if self.dsample is None else self.dsample(x)
        out += shortcut
        return self.relu(out)

class Resnet50(nn.Module):
    """ResNet-50 backbone that returns the outputs of stages C2-C5.

    The stages stack [3, 4, 6, 3] ``Res`` bottleneck blocks. Output channels
    are 256, 512, 1024 and 2048. The strides relative to the input are 4, 8,
    16 and 32.
    """

    def __init__(self):
        super().__init__()
        # Stem: 7x7 stride-2 conv followed by a 3x3 stride-2 max-pool (1/4 resolution).
        self.make_c1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.make_c2 = self._make_stage(64, 64, num_blocks=3, stride=1)
        self.make_c3 = self._make_stage(256, 128, num_blocks=4, stride=2)
        self.make_c4 = self._make_stage(512, 256, num_blocks=6, stride=2)
        self.make_c5 = self._make_stage(1024, 512, num_blocks=3, stride=2)

    @staticmethod
    def _make_stage(in_channel, inner_channel, num_blocks, stride):
        """Stack ``num_blocks`` Res blocks; only the first changes channels/stride."""
        out_channel = 4 * inner_channel
        blocks = [Res(in_channel=in_channel, inner_channel=inner_channel, stride=stride)]
        blocks.extend(
            Res(in_channel=out_channel, inner_channel=inner_channel, stride=1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the tuple ``(c2, c3, c4, c5)`` of stage outputs."""
        stage_outputs = []
        feat = self.make_c1(x)
        for stage in (self.make_c2, self.make_c3, self.make_c4, self.make_c5):
            feat = stage(feat)
            stage_outputs.append(feat)
        return tuple(stage_outputs)

class FPN(nn.Module):
    """Top-down feature pyramid over the ResNet-50 stages C2-C5.

    Every pyramid level has 256 channels. P5 is a 1x1 projection of C5. Each
    lower level is built as follows: the level above is nearest-upsampled,
    added to a 1x1 lateral projection of the matching C stage, and smoothed
    with a 3x3 conv.
    """

    def __init__(self):
        super().__init__()
        self.make_p5 = nn.Conv2d(512 * 4, 256, 1, 1, 0)
        # Lateral 1x1 convs bring C4/C3/C2 to 256 channels so they can be summed.
        self.lat_c4 = nn.Conv2d(1024, 256, 1, 1, 0)
        self.lat_c3 = nn.Conv2d(512, 256, 1, 1, 0)
        self.lat_c2 = nn.Conv2d(256, 256, 1, 1, 0)
        # 3x3 convs that smooth each merged map (P4, P3, P2 respectively).
        self.smooth1 = nn.Conv2d(256, 256, 3, 1, 1)
        self.smooth2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.smooth3 = nn.Conv2d(256, 256, 3, 1, 1)

    def _upsample_add(self, x, y):
        """Nearest-upsample ``x`` to ``y``'s spatial size and add it into ``y`` in place."""
        y += nn.functional.interpolate(x, size=y.shape[2:])
        return y

    def forward(self, x):
        """Take ``(c2, c3, c4, c5)`` and return ``(p2, p3, p4, p5)``."""
        c2, c3, c4, c5 = x
        top = self.make_p5(c5)
        levels = [top]
        steps = (
            (self.lat_c4, self.smooth1, c4),
            (self.lat_c3, self.smooth2, c3),
            (self.lat_c2, self.smooth3, c2),
        )
        for lateral, smooth, stage_feat in steps:
            top = smooth(self._upsample_add(top, lateral(stage_feat)))
            levels.append(top)
        p5, p4, p3, p2 = levels
        return p2, p3, p4, p5

class SegoutDetector(nn.Module):
    """DB head: predicts probability, threshold and approximate-binary maps.

    Takes the four FPN outputs ``(p2, p3, p4, p5)``, each with 256 channels.
    Each level is reduced to 64 channels and nearest-upsampled to p2's spatial
    size; the four results are concatenated into a 256-channel fused map. Two
    separate branches (3x3 conv, then two stride-2 transposed convs, then a
    sigmoid) predict the probability map and the threshold map. Both maps are
    4x p2's spatial size.

    Args:
        serial: if True, the threshold branch sees the fused features
            concatenated with the downsampled probability map (257 input
            channels); if False, it sees the fused features only (256).
    """

    def __init__(self, serial=False):
        super().__init__()
        self.serial = serial
        # One 3x3 conv per pyramid level (p2, p3, p4, p5): 256 -> 64 channels.
        # They are registered here so their weights are trained, saved in
        # state_dict and moved by .to(device). Creating them inside merge()
        # re-initialised them randomly on every forward call.
        self.merge_convs = nn.ModuleList(nn.Conv2d(256, 64, 3, 1, 1) for _ in range(4))
        # Probability map branch.
        self.binarize = self._make_branch(256)
        # Threshold map branch. It always has its own weights: reusing
        # self.binarize made threshold_map identical to prob_map, and
        # therefore ab_map a constant 0.5.
        self.threshold = self._make_branch(257 if serial else 256)

    @staticmethod
    def _make_branch(in_channels):
        """Build a conv -> 2x deconv -> 2x deconv -> sigmoid single-channel map predictor."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, 2, 2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels=1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the head on ``x = (p2, p3, p4, p5)``.

        Returns:
            In training mode, ``(prob_map, threshold_map, ab_map)``;
            in eval mode, only ``prob_map``.
        """
        p2, p3, p4, p5 = x
        fuse = self.merge(p2, p3, p4, p5)
        prob_map = self.binarize(fuse)
        # Only the probability map is needed at inference time.
        if not self.training:
            return prob_map
        if self.serial:
            # Downsample prob_map to the fused map's resolution and append it
            # as an extra input channel.
            prob_small = nn.functional.interpolate(prob_map, size=fuse.shape[2:])
            threshold_input = torch.cat((fuse, prob_small), dim=1)
        else:
            threshold_input = fuse
        threshold_map = self.threshold(threshold_input)
        ab_map = self.ab_map(prob_map, threshold_map)
        return prob_map, threshold_map, ab_map

    def merge(self, p2, p3, p4, p5):
        """Reduce each level to 64 channels, resize to p2's size and concatenate (256 ch)."""
        size = p2.shape[2:]
        reduced = []
        for conv, feat in zip(self.merge_convs, (p2, p3, p4, p5)):
            out = conv(feat)
            if out.shape[2:] != size:
                out = nn.functional.interpolate(out, size=size)
            reduced.append(out)
        return torch.cat(reduced, dim=1)

    def ab_map(self, x, y, k=50):
        """Approximate binary map ``1 / (1 + exp(-k * (x - y)))``.

        Computed as ``sigmoid(k * (x - y))``, which is the same function but
        does not overflow ``exp`` for large negative arguments.
        """
        return torch.sigmoid(k * (x - y))

if __name__ == "__main__":
    model = DBnet()
    # Number of entries (parameters + buffers) in the model's state dict.
    print(len(model.state_dict()))
    dummy = torch.randn(2, 3, 512, 512)
    # A freshly constructed module is in training mode, so all three maps come back.
    prob_map, threshold_map, ab_map = model(dummy)
    print(prob_map.shape, threshold_map.shape, ab_map.shape)

# if __name__=="__main__":
#     res=Res(64,64,2)
#     print(res)

# if __name__=="__main__":
#     r=Resnet50()
#     x=torch.randn(1,3,512,512)
#     c2, c3, c4, c5=r(x)
#     print(c2.shape,c3.shape,c4.shape,c5.shape)

# if __name__=="__main__":
#     r=Resnet50()
#     f=FPN()
#     x=torch.randn(1,3,512,512)
#     c2,c3,c4,c5=r(x)
#     p2,p3,p4,p5=f(c2,c3,c4,c5)
#     print(p2.shape)
#     print(p3.shape)
#     print(p4.shape)
#     print(p5.shape)


你可能感兴趣的:(算法,神经网络)