Instance segmentation paper: PolarMask: Single Shot Instance Segmentation with Polar Representation, with a PyTorch implementation

PolarMask: Single Shot Instance Segmentation with Polar Representation
PDF: https://arxiv.org/pdf/1909.13226.pdf
PyTorch: https://github.com/xieenze/PolarMask
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks
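Key contributions:
1. Turns the two-stage "detect, then segment" instance segmentation framework (e.g. Mask R-CNN) into a single-stage framework, without a noticeable increase in computation compared with object detection.
2. Proposes instance segmentation by modeling the object contour in polar coordinates, extending the 4 rays of object detection (the distances from a location to the four sides of the box) to 36 rays.

  • Compared with pixel-level modeling, modeling the contour in polar coordinates simplifies the problem and requires less computation;
  • Compared with modeling the contour in Cartesian coordinates, polar modeling can exploit the prior of fixed ray angles, which simplifies the problem further.
3. Proposes a new Polar Centerness, used to select high-quality positive samples.
4. Proposes the Polar IoU Loss, which approximates the IoU between the predicted mask and the ground-truth mask. Optimizing this IoU loss improves the mask regression, and the experiments show that Polar IoU Loss is more effective than Smooth L1 loss.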
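To make the polar representation in point 2 concrete, here is a minimal sketch of how a center point and a set of ray lengths can be turned back into contour points. The function name `polar_to_contour` is hypothetical, and the angle convention (rays evenly spaced, starting at 0 degrees along the positive x axis) is an assumption of this sketch rather than something taken from the code in this post.

```python
import math
import torch

def polar_to_contour(center, distances):
    """Convert ray lengths around a center into (x, y) contour points.

    center:    tensor of shape (2,), the (x, y) position of the instance center.
    distances: tensor of shape (n,), ray lengths at n evenly spaced angles
               (n = 36 gives one ray every 10 degrees).
    """
    n = distances.shape[0]
    # Assumed convention: ray i points at angle i * 360 / n degrees.
    angles = torch.arange(n, dtype=distances.dtype) * (2 * math.pi / n)
    xs = center[0] + distances * torch.cos(angles)
    ys = center[1] + distances * torch.sin(angles)
    return torch.stack([xs, ys], dim=1)  # shape (n, 2)

center = torch.tensor([50.0, 40.0])
distances = torch.full((36,), 10.0)  # equal rays describe a circle of radius 10
contour = polar_to_contour(center, distances)
print(contour.shape)
```

Because the angles are fixed, the network only has to regress the 36 lengths per location; the contour itself follows from the center and those lengths.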

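Points 3 and 4 both operate directly on the ray lengths. The sketch below follows the definitions in the paper: Polar Centerness is sqrt(min(d) / max(d)) over the rays of one location, and Polar IoU Loss is log(sum(max(d, d*)) / sum(min(d, d*))) for predicted lengths d and target lengths d*. The function names, the `eps` stabilizer, and the averaging over locations are assumptions of this sketch and are not taken from the official repository.

```python
import torch

def polar_centerness(distances, eps=1e-6):
    """Polar Centerness per location. distances: (num_points, n), positive ray lengths."""
    d_min = distances.min(dim=-1).values
    d_max = distances.max(dim=-1).values
    # eps is an assumption added here to avoid division by zero.
    return torch.sqrt(d_min / (d_max + eps))

def polar_iou_loss(pred, target, eps=1e-6):
    """Polar IoU Loss. pred, target: (num_points, n), positive ray lengths."""
    d_max = torch.max(pred, target).sum(dim=-1)
    d_min = torch.min(pred, target).sum(dim=-1)
    # Averaging over locations is a choice made for this sketch.
    return torch.log((d_max + eps) / (d_min + eps)).mean()

# A location whose rays alternate between length 4 and 16 is far from "central".
uneven = torch.tensor([[4.0, 16.0] * 18])
print(polar_centerness(uneven))        # approximately 0.5

target = torch.full((1, 36), 10.0)
print(polar_iou_loss(target * 0.5, target))  # approximately log(2) = 0.693
```

Locations whose rays have similar lengths get a centerness close to 1, and the loss falls to 0 as the predicted lengths approach the targets.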
PyTorch code:

import torch
import torch.nn as nn
import torchvision

def Conv3x3ReLU(in_channels, out_channels):
    # 3x3 convolution (stride 1, padding 1 keeps the spatial size) followed by ReLU6
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU6(inplace=True)
    )

def locLayer(in_channels, out_channels):
    # Regression head: four 3x3 conv blocks, then a 3x3 conv with one output channel per ray
    return nn.Sequential(
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        )

def conf_centernessLayer(in_channels, out_channels):
    # Shared head for classification scores and centerness; same layout as locLayer
    return nn.Sequential(
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
    )

class PolarMask(nn.Module):
    def __init__(self, num_classes=21):
        super(PolarMask, self).__init__()
        self.num_classes = num_classes
        resnet = torchvision.models.resnet50()
        layers = list(resnet.children())

        self.layer1 = nn.Sequential(*layers[:5])
        self.layer2 = nn.Sequential(*layers[5])
        self.layer3 = nn.Sequential(*layers[6])
        self.layer4 = nn.Sequential(*layers[7])

        self.lateral5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral3 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)

        self.upsample4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)

        self.downsample6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.downsample5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)

        # Each conf/centerness head outputs num_classes + 1 channels:
        # num_classes class scores plus 1 centerness channel, split apart in forward().
        self.loc_layer3 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer3 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer4 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer4 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer5 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer5 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer6 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer6 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer7 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer7 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.layer1(x)
        c3 = x = self.layer2(x)
        c4 = x = self.layer3(x)
        c5 = x = self.layer4(x)

        p5 = self.lateral5(c5)
        p4 = self.upsample4(p5) + self.lateral4(c4)
        p3 = self.upsample3(p4) + self.lateral3(c3)

        p6 = self.downsample5(p5)
        p7 = self.downsample6(p6)

        loc3 = self.loc_layer3(p3)
        conf_centerness3 = self.conf_centerness_layer3(p3)
        conf3, centerness3 = conf_centerness3.split([self.num_classes, 1], dim=1)

        loc4 = self.loc_layer4(p4)
        conf_centerness4 = self.conf_centerness_layer4(p4)
        conf4, centerness4 = conf_centerness4.split([self.num_classes, 1], dim=1)

        loc5 = self.loc_layer5(p5)
        conf_centerness5 = self.conf_centerness_layer5(p5)
        conf5, centerness5 = conf_centerness5.split([self.num_classes, 1], dim=1)

        loc6 = self.loc_layer6(p6)
        conf_centerness6 = self.conf_centerness_layer6(p6)
        conf6, centerness6 = conf_centerness6.split([self.num_classes, 1], dim=1)

        loc7 = self.loc_layer7(p7)
        conf_centerness7 = self.conf_centerness_layer7(p7)
        conf7, centerness7 = conf_centerness7.split([self.num_classes, 1], dim=1)

        locs = torch.cat([loc3.permute(0, 2, 3, 1).contiguous().view(loc3.size(0), -1),
                    loc4.permute(0, 2, 3, 1).contiguous().view(loc4.size(0), -1),
                    loc5.permute(0, 2, 3, 1).contiguous().view(loc5.size(0), -1),
                    loc6.permute(0, 2, 3, 1).contiguous().view(loc6.size(0), -1),
                    loc7.permute(0, 2, 3, 1).contiguous().view(loc7.size(0), -1)],dim=1)

        confs = torch.cat([conf3.permute(0, 2, 3, 1).contiguous().view(conf3.size(0), -1),
                           conf4.permute(0, 2, 3, 1).contiguous().view(conf4.size(0), -1),
                           conf5.permute(0, 2, 3, 1).contiguous().view(conf5.size(0), -1),
                           conf6.permute(0, 2, 3, 1).contiguous().view(conf6.size(0), -1),
                           conf7.permute(0, 2, 3, 1).contiguous().view(conf7.size(0), -1),], dim=1)

        centernesses = torch.cat([centerness3.permute(0, 2, 3, 1).contiguous().view(centerness3.size(0), -1),
                           centerness4.permute(0, 2, 3, 1).contiguous().view(centerness4.size(0), -1),
                           centerness5.permute(0, 2, 3, 1).contiguous().view(centerness5.size(0), -1),
                           centerness6.permute(0, 2, 3, 1).contiguous().view(centerness6.size(0), -1),
                           centerness7.permute(0, 2, 3, 1).contiguous().view(centerness7.size(0), -1), ], dim=1)

        out = (locs, confs, centernesses)
        return out

if __name__ == '__main__':
    model = PolarMask()
    print(model)

    inputs = torch.randn(1, 3, 800, 1024)  # renamed to avoid shadowing the built-in input()
    out = model(inputs)
    print(out[0].shape)
    print(out[1].shape)
    print(out[2].shape)
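The three outputs above are flattened across all pyramid levels, so a common next step is to regroup them per location. The sketch below reshapes them to one row per location and combines class scores with centerness. The function name `fuse_outputs` is hypothetical, and applying a sigmoid to both logits before multiplying them is an assumption borrowed from FCOS-style heads; the model in this post applies no activation itself. Random tensors stand in for the model outputs so the snippet runs on its own.

```python
import torch

def fuse_outputs(locs, confs, centernesses, num_rays=36, num_classes=21):
    """Regroup flattened head outputs per location and fuse scores.

    locs:         (B, num_locations * num_rays)
    confs:        (B, num_locations * num_classes)
    centernesses: (B, num_locations)
    """
    b = locs.size(0)
    rays = locs.reshape(b, -1, num_rays)            # (B, num_locations, num_rays)
    cls_logits = confs.reshape(b, -1, num_classes)  # (B, num_locations, num_classes)
    ctr_logits = centernesses.reshape(b, -1, 1)     # (B, num_locations, 1)
    # Assumption: sigmoid on both logits, then down-weight classes by centerness.
    scores = cls_logits.sigmoid() * ctr_logits.sigmoid()
    return rays, scores

# Stand-in tensors with 10 locations; real inputs would come from model(x).
num_locations = 10
rays, scores = fuse_outputs(torch.randn(2, num_locations * 36),
                            torch.randn(2, num_locations * 21),
                            torch.randn(2, num_locations))
print(rays.shape, scores.shape)
```

The reshape is valid because each level is flattened with channels last before concatenation, so every group of 36 (or 21) consecutive values belongs to a single location.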
