chainer-语义分割-DeepLab【附源码】

文章目录

  • 前言
  • 一、DeepLab官网图
  • 二、代码实现
    • 1.DeepLab主干网络实现,以MobileNet为例,后期可自行修改主干网络
    • 2.DeepLab网络结构
  • 三、调用方式


前言

本文主要基于DeepLab的网络结构,使用Chainer搭建语义分割模型。
配合文章:语义分割框架链接


一、DeepLab官网图

chainer-语义分割-DeepLab【附源码】_第1张图片
chainer-语义分割-DeepLab【附源码】_第2张图片


二、代码实现

1.DeepLab主干网络实现,以MobileNet为例,后期可自行修改主干网络

import chainer
import chainer.functions as F
import chainer.links as L

class Block(chainer.Chain):
    """MobileNetV2-style inverted residual block.

    The block runs three convolutions, each followed by batch
    normalization:

    1. a 1x1 convolution that widens the input to
       ``n_input * expantion`` channels,
    2. a ``ksize`` x ``ksize`` convolution with one group per hidden
       channel (carrying the stride and dilation),
    3. a 1x1 convolution that projects to ``n_output`` channels.

    The first two stages are activated with a ReLU clipped at 6; the
    projection is left linear. The input is added back to the output
    only when ``stride == 1`` and ``n_input == n_output``.

    Args:
        n_input: Number of input channels.
        n_output: Number of output channels.
        expantion: Multiplier applied to ``n_input`` to get the hidden
            width (the spelling is kept for caller compatibility).
        ksize: Kernel size of the middle convolution.
        stride: Stride of the middle convolution.
        dilate: Dilation of the middle convolution.
        nobias: Passed to all three convolutions.
    """

    def __init__(self, n_input, n_output, expantion, ksize=3, stride=1, dilate=1, nobias=True):
        super(Block, self).__init__()
        n_hidden = n_input * expantion
        # Half of the dilated kernel extent on each side.
        pad = ((ksize - 1) * dilate) // 2

        self.skip_connection = (n_input == n_output)
        self.stride = stride

        with self.init_scope():
            # Expansion stage.
            self.conv1 = L.Convolution2D(
                n_input, n_hidden, ksize=1, nobias=nobias)
            self.bn1 = L.BatchNormalization(n_hidden)
            # Per-channel spatial filtering (groups == channels).
            self.conv2 = L.Convolution2D(
                n_hidden, n_hidden,
                ksize=ksize, stride=stride, pad=pad, dilate=dilate,
                groups=n_hidden, nobias=nobias)
            self.bn2 = L.BatchNormalization(n_hidden)
            # Linear projection stage.
            self.conv3 = L.Convolution2D(
                n_hidden, n_output, ksize=1, nobias=nobias)
            self.bn3 = L.BatchNormalization(n_output)

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting variable."""
        expanded = F.clipped_relu(self.bn1(self.conv1(x)), z=6.)
        filtered = F.clipped_relu(self.bn2(self.conv2(expanded)), z=6.)
        projected = self.bn3(self.conv3(filtered))

        # The residual path is only valid when the block changes neither
        # the resolution nor the channel count.
        if not (self.stride == 1 and self.skip_connection):
            return projected
        return projected + x

class MobileNetV2(chainer.Chain):
    """MobileNetV2 feature extractor used as a DeepLab backbone.

    A strided 3x3 stem convolution (3 -> 32 channels) is followed by
    seven stages of ``Block`` links registered as ``block{stage}_{rep}``
    (both indices 1-based). ``forward`` returns the final feature map
    together with an earlier, higher-resolution feature map.

    NOTE(review): the stem and stage strides multiply to 32 and there is
    no ``output_stride`` option here - confirm this is what the ASPP
    head is expected to receive.
    """

    def __init__(self):
        super(MobileNetV2, self).__init__()
        # One row per stage:
        # (expansion factor, output channels, repeats, first-block stride)
        stage_specs = (
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        )

        with self.init_scope():
            self.conv1 = L.Convolution2D(3, 32, ksize=3, stride=2, pad=1, nobias=True)
            self.bn1 = L.BatchNormalization(32)

            # Ordered list of block attribute names, replayed in forward().
            self._forward = []
            in_channels = 32
            for stage_idx, (factor, out_channels, repeats, first_stride) in enumerate(stage_specs, start=1):
                for rep_idx in range(1, repeats + 1):
                    # Only the first block of a stage may change the
                    # resolution / channel count; the rest keep both.
                    block_stride = first_stride if rep_idx == 1 else 1
                    block = Block(in_channels, out_channels, factor, stride=block_stride)
                    in_channels = out_channels

                    block_name = 'block{}_{}'.format(stage_idx, rep_idx)
                    setattr(self, block_name, block)
                    self._forward.append(block_name)

    def forward(self, x):
        """Run the backbone.

        Returns:
            A tuple ``(features, low_level_features)``: the output of the
            last block, and the activation that enters ``block3_1`` (i.e.
            the 24-channel output of the second stage).
        """
        feat = F.clipped_relu(self.bn1(self.conv1(x)), z=6.)
        for block_name in self._forward:
            if block_name == 'block3_1':
                # Snapshot taken *before* block3_1 is applied.
                low_level_features = feat
            feat = getattr(self, block_name)(feat)
        return feat, low_level_features

2.DeepLab网络结构

def backbone_module(backbone, output_stride):
    """Instantiate the feature-extractor network selected by name.

    Args:
        backbone: One of ``'xception'``, ``'mobilenet'`` or ``'vgg'``.
            ``'resnet'`` is recognised but not implemented.
        output_stride: Forwarded to ``Xception`` only; the other
            backbones are constructed without it.

    Returns:
        A new backbone link.

    Raises:
        NotImplementedError: If ``backbone`` is ``'resnet'``.
        ValueError: If ``backbone`` is not a known name.

    NOTE(review): ``Xception`` and ``VGG16`` are not defined in the
    visible part of this file and must be provided elsewhere.
    """
    if backbone == 'resnet':
        raise NotImplementedError(
            "the 'resnet' backbone is not implemented")
    if backbone == 'xception':
        return Xception(output_stride=output_stride)
    if backbone == 'mobilenet':
        return MobileNetV2()
    if backbone == 'vgg':
        return VGG16()
    # Name the offending value so a typo in a config is easy to spot.
    raise ValueError(
        "unknown backbone {!r}; expected one of "
        "'resnet', 'xception', 'mobilenet', 'vgg'".format(backbone))

class ASPPModule(chainer.Chain):
    """Single ASPP branch: bias-free convolution, batch norm, ReLU.

    Args:
        n_input: Number of input channels.
        n_output: Number of output channels.
        ksize: Convolution kernel size.
        pad: Convolution padding.
        dilate: Convolution dilation.
    """

    def __init__(self, n_input, n_output, ksize, pad, dilate):
        super(ASPPModule, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(
                n_input, n_output,
                ksize=ksize, pad=pad, dilate=dilate, nobias=True)
            self.bn = L.BatchNormalization(n_output)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return F.relu(normalized)

class ASPP(chainer.Chain):
    """Atrous spatial pyramid pooling head.

    Five parallel branches are computed from the backbone features: a
    1x1 convolution, three dilated 3x3 convolutions, and a globally
    averaged branch that is resized back to the feature-map size. Their
    256-channel outputs are concatenated (1280 channels) and fused back
    to 256 channels by a 1x1 convolution, followed by dropout.

    Args:
        backbone: Backbone name; selects the expected number of input
            channels (``'resnet'``/``'xception'``: 2048,
            ``'mobilenet'``: 320, ``'vgg'``: 512).
        output_stride: 16 or 8; selects the dilation rates.

    Raises:
        ValueError: If ``backbone`` or ``output_stride`` is unsupported.
    """

    def __init__(self, backbone, output_stride):
        super(ASPP, self).__init__()

        # Input width depends on which backbone produced the features.
        for names, channels in (
                (('resnet', 'xception'), 2048),
                (('mobilenet',), 320),
                (('vgg',), 512)):
            if backbone in names:
                n_input = channels
                break
        else:
            raise ValueError

        if output_stride not in (16, 8):
            raise ValueError
        rates = [1, 6, 12, 18] if output_stride == 16 else [1, 12, 24, 36]
        rate1, rate2, rate3, rate4 = rates

        with self.init_scope():
            # For the 3x3 branches the padding equals the dilation rate.
            self.aspp1 = ASPPModule(n_input, 256, ksize=1, pad=0, dilate=rate1)
            self.aspp2 = ASPPModule(n_input, 256, ksize=3, pad=rate2, dilate=rate2)
            self.aspp3 = ASPPModule(n_input, 256, ksize=3, pad=rate3, dilate=rate3)
            self.aspp4 = ASPPModule(n_input, 256, ksize=3, pad=rate4, dilate=rate4)

            # Projection applied to the globally pooled features.
            self.conv1 = L.Convolution2D(n_input, 256, ksize=1, nobias=True)
            self.bn1 = L.BatchNormalization(256)

            # Fusion of the five concatenated 256-channel branches.
            self.conv2 = L.Convolution2D(1280, 256, ksize=1, nobias=True)
            self.bn2 = L.BatchNormalization(256)

    def forward(self, x):
        """Return the fused 256-channel ASPP features for ``x``."""
        branch1 = self.aspp1(x)
        branch2 = self.aspp2(x)
        branch3 = self.aspp3(x)
        branch4 = self.aspp4(x)

        # Image-level branch: average over the two spatial axes, project,
        # then resize to the spatial size of the convolutional branches.
        pooled = F.average(x, axis=(2, 3), keepdims=True)
        pooled = F.relu(self.bn1(self.conv1(pooled)))
        pooled = F.resize_images(pooled, branch4.shape[2:])

        merged = F.concat((branch1, branch2, branch3, branch4, pooled))
        fused = F.relu(self.bn2(self.conv2(merged)))
        return F.dropout(fused)

class Decoder(chainer.Chain):
    """DeepLab decoder that merges ASPP output with low-level features.

    The low-level features are reduced to 48 channels, concatenated with
    the resized 256-channel ASPP output (304 channels in total), refined
    by two 3x3 convolutions with dropout, and mapped to ``n_class``
    channels by a final 1x1 convolution.

    Args:
        n_class: Number of output channels (one per class).
        backbone: Backbone name; selects the channel count of the
            low-level features (``'xception'``/``'vgg'``: 128,
            ``'mobilenet'``: 24).

    Raises:
        NotImplementedError: If ``backbone`` is ``'resnet'``.
        ValueError: If ``backbone`` is not a known name.
    """

    def __init__(self, n_class, backbone):
        super(Decoder, self).__init__()

        # ResNet is recognised but unsupported (the original author
        # noted 256 low-level channels for it).
        if backbone == 'resnet':
            raise NotImplementedError

        for names, channels in (
                (('xception', 'vgg'), 128),
                (('mobilenet',), 24)):
            if backbone in names:
                n_low_level_features = channels
                break
        else:
            raise ValueError

        with self.init_scope():
            # Channel reduction of the low-level features.
            self.conv1 = L.Convolution2D(
                n_low_level_features, 48, ksize=1, nobias=True)
            self.bn1 = L.BatchNormalization(48)
            # 304 = 256 (ASPP output) + 48 (reduced low-level features).
            self.conv2 = L.Convolution2D(
                304, 256, ksize=3, pad=1, nobias=True)
            self.bn2 = L.BatchNormalization(256)
            self.conv3 = L.Convolution2D(
                256, 256, ksize=3, pad=1, nobias=True)
            self.bn3 = L.BatchNormalization(256)
            # Per-pixel classifier (keeps its bias).
            self.conv4 = L.Convolution2D(256, n_class, ksize=1)

    def forward(self, x, low_level_features):
        """Return per-class score maps at the low-level feature resolution.

        Args:
            x: ASPP output.
            low_level_features: Earlier backbone feature map.
        """
        reduced = F.relu(self.bn1(self.conv1(low_level_features)))

        # Bring the ASPP output to the low-level resolution before merging.
        upsampled = F.resize_images(x, reduced.shape[2:])
        merged = F.concat((upsampled, reduced))

        refined = F.relu(self.bn2(self.conv2(merged)))
        refined = F.dropout(refined, ratio=0.5)
        refined = F.relu(self.bn3(self.conv3(refined)))
        refined = F.dropout(refined, ratio=0.1)

        return self.conv4(refined)


class DeepLab(chainer.Chain):
    """DeepLab segmentation network: backbone, ASPP head and decoder.

    Args:
        n_class: Number of output channels (one per class).
        backbone: Backbone name understood by ``backbone_module``,
            ``ASPP`` and ``Decoder``.
        output_stride: Passed to ``backbone_module`` and ``ASPP``.
    """

    def __init__(self, n_class, backbone='xception', output_stride=16):
        super(DeepLab, self).__init__()
        with self.init_scope():
            self.backbone = backbone_module(backbone, output_stride)
            self.aspp = ASPP(backbone, output_stride)
            self.decoder = Decoder(n_class, backbone)

    def forward(self, x):
        """Return score maps with the same spatial size as ``x``.

        Args:
            x: Input batch; axes 2 and 3 are treated as the spatial axes.

        Returns:
            The decoder output, resized to ``x.shape[2:]`` when its
            spatial size differs from the input's.
        """
        h, low_level_features = self.backbone(x)
        h = self.aspp(h)
        h = self.decoder(h, low_level_features)
        # Compare spatial axes only: the channel axis of the scores
        # (n_class) is unrelated to the channel axis of the input, so a
        # full-shape comparison would trigger a resize even when the
        # resolutions already agree.
        if h.shape[2:] != x.shape[2:]:
            h = F.resize_images(h, x.shape[2:])

        return h

三、调用方式

# Build the DeepLab network with one output channel per class name.
model = DeepLab(n_class=len(self.classes_names), backbone=backbone)
# Wrap the network with softmax cross-entropy as the loss function.
# NOTE(review): ModifiedClassifier is not defined in this article -
# presumably it comes from the companion framework post; confirm.
self.model = ModifiedClassifier(model, lossfun=F.softmax_cross_entropy)
# A non-negative gpu_devices is used as a CUDA device id: make that
# device current, then move the wrapped model onto it.
if self.gpu_devices>=0:
    chainer.cuda.get_device_from_id(self.gpu_devices).use()
    self.model.to_gpu()

你可能感兴趣的:(深度学习-chainer,深度学习,python,人工智能)