This article builds the network structure based on DeepLab.
Companion article: semantic segmentation framework link
import chainer
import chainer.functions as F
import chainer.links as L


class Block(chainer.Chain):
    """MobileNetV2 inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project."""

    def __init__(self, n_input, n_output, expansion, ksize=3, stride=1, dilate=1, nobias=True):
        super(Block, self).__init__()
        pad = ((ksize - 1) * dilate) // 2
        n_hidden = n_input * expansion
        self.stride = stride
        self.skip_connection = (n_input == n_output)
        with self.init_scope():
            self.conv1 = L.Convolution2D(n_input, n_hidden, ksize=1, nobias=nobias)
            self.bn1 = L.BatchNormalization(n_hidden)
            # depthwise convolution (groups == number of channels)
            self.conv2 = L.Convolution2D(n_hidden, n_hidden, ksize=ksize, stride=stride, pad=pad,
                                         dilate=dilate, groups=n_hidden, nobias=nobias)
            self.bn2 = L.BatchNormalization(n_hidden)
            self.conv3 = L.Convolution2D(n_hidden, n_output, ksize=1, nobias=nobias)
            self.bn3 = L.BatchNormalization(n_output)

    def forward(self, x):
        h = F.clipped_relu(self.bn1(self.conv1(x)), z=6.)  # ReLU6
        h = F.clipped_relu(self.bn2(self.conv2(h)), z=6.)
        h = self.bn3(self.conv3(h))  # linear bottleneck, no activation
        if self.stride == 1 and self.skip_connection:
            h = h + x
        return h
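A minimal sanity check for the inverted residual block, assuming a hypothetical 24-channel dummy feature map (shapes are for illustration only):

import numpy as np

x = np.zeros((1, 24, 56, 56), dtype=np.float32)  # dummy input: batch 1, 24 channels, 56x56
block = Block(24, 24, 6)  # n_input == n_output and stride == 1, so the residual skip is used
with chainer.using_config('train', False):
    y = block(x)
print(y.shape)  # (1, 24, 56, 56)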
class MobileNetV2(chainer.Chain):
    def __init__(self):
        super(MobileNetV2, self).__init__()
        # (expansion t, output channels c, number of repeats n, stride s)
        params = [
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        ]
        with self.init_scope():
            self.conv1 = L.Convolution2D(3, 32, ksize=3, stride=2, pad=1, nobias=True)
            self.bn1 = L.BatchNormalization(32)
            self._forward = []
            n_input = 32
            for i, (t, c, n, s) in enumerate(params):
                for j in range(n):
                    name = 'block{}_{}'.format(i + 1, j + 1)
                    if j == 0:
                        # only the first block of each stage changes channels / downsamples
                        block = Block(n_input, c, t, stride=s)
                        n_input = c
                    else:
                        block = Block(n_input, n_input, t)
                    setattr(self, name, block)
                    self._forward.append(name)

    def forward(self, x):
        h = F.clipped_relu(self.bn1(self.conv1(x)), z=6.)
        for name in self._forward:
            if name == 'block3_1':
                # keep the stride-4 feature map (24 channels) for the decoder
                low_level_features = h
            block = getattr(self, name)
            h = block(h)
        return h, low_level_features
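To see what the backbone returns, a quick check with a dummy 224x224 RGB image (purely illustrative values):

import numpy as np

x = np.zeros((1, 3, 224, 224), dtype=np.float32)  # dummy image
backbone = MobileNetV2()
with chainer.using_config('train', False):
    h, low = backbone(x)
print(h.shape)    # (1, 320, 7, 7)   -- final feature map, overall stride 32
print(low.shape)  # (1, 24, 56, 56)  -- low-level features at stride 4, used by the decoder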
def backbone_module(backbone, output_stride):
    # Xception and VGG16 are defined in the companion framework article
    if backbone == 'resnet':
        raise NotImplementedError
    elif backbone == 'xception':
        return Xception(output_stride=output_stride)
    elif backbone == 'mobilenet':
        return MobileNetV2()
    elif backbone == 'vgg':
        return VGG16()
    else:
        raise ValueError('unknown backbone: {}'.format(backbone))
class ASPPModule(chainer.Chain):
    def __init__(self, n_input, n_output, ksize, pad, dilate):
        super(ASPPModule, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(n_input, n_output, ksize=ksize, pad=pad, dilate=dilate, nobias=True)
            self.bn = L.BatchNormalization(n_output)

    def forward(self, x):
        h = F.relu(self.bn(self.conv(x)))
        return h
class ASPP(chainer.Chain):
    def __init__(self, backbone, output_stride):
        super(ASPP, self).__init__()
        # number of channels produced by each backbone
        if backbone in ['resnet', 'xception']:
            n_input = 2048
        elif backbone == 'mobilenet':
            n_input = 320
        elif backbone == 'vgg':
            n_input = 512
        else:
            raise ValueError('unknown backbone: {}'.format(backbone))
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise ValueError('output_stride must be 8 or 16')
        with self.init_scope():
            self.aspp1 = ASPPModule(n_input, 256, ksize=1, pad=0, dilate=dilations[0])
            self.aspp2 = ASPPModule(n_input, 256, ksize=3, pad=dilations[1], dilate=dilations[1])
            self.aspp3 = ASPPModule(n_input, 256, ksize=3, pad=dilations[2], dilate=dilations[2])
            self.aspp4 = ASPPModule(n_input, 256, ksize=3, pad=dilations[3], dilate=dilations[3])
            # image-level pooling branch
            self.conv1 = L.Convolution2D(n_input, 256, ksize=1, nobias=True)
            self.bn1 = L.BatchNormalization(256)
            # 1x1 projection after concatenating the five 256-channel branches (5 * 256 = 1280)
            self.conv2 = L.Convolution2D(1280, 256, ksize=1, nobias=True)
            self.bn2 = L.BatchNormalization(256)

    def forward(self, x):
        h1 = self.aspp1(x)
        h2 = self.aspp2(x)
        h3 = self.aspp3(x)
        h4 = self.aspp4(x)
        # global average pooling -> 1x1 conv -> upsample back to the feature-map size
        h5 = F.average(x, axis=(2, 3), keepdims=True)
        h5 = F.relu(self.bn1(self.conv1(h5)))
        h5 = F.resize_images(h5, h4.shape[2:])
        h = F.concat((h1, h2, h3, h4, h5))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.dropout(h)
        return h
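The ASPP head can be checked the same way on a feature map shaped like the MobileNetV2 output above (again a hedged sketch; the dilation rates follow output_stride=16):

import numpy as np

features = np.zeros((1, 320, 7, 7), dtype=np.float32)  # dummy backbone output
aspp = ASPP(backbone='mobilenet', output_stride=16)
with chainer.using_config('train', False):
    h = aspp(features)
print(h.shape)  # (1, 256, 7, 7) -- five 256-channel branches concatenated, then projected back to 256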
class Decoder(chainer.Chain):
    def __init__(self, n_class, backbone):
        super(Decoder, self).__init__()
        # number of channels in the low-level feature map of each backbone
        if backbone == 'resnet':
            # n_low_level_features = 256
            raise NotImplementedError
        elif backbone == 'xception':
            n_low_level_features = 128
        elif backbone == 'mobilenet':
            n_low_level_features = 24
        elif backbone == 'vgg':
            n_low_level_features = 128
        else:
            raise ValueError('unknown backbone: {}'.format(backbone))
        with self.init_scope():
            self.conv1 = L.Convolution2D(n_low_level_features, 48, ksize=1, nobias=True)
            self.bn1 = L.BatchNormalization(48)
            # 256 (ASPP output) + 48 (low-level features) = 304 input channels
            self.conv2 = L.Convolution2D(304, 256, ksize=3, pad=1, nobias=True)
            self.bn2 = L.BatchNormalization(256)
            self.conv3 = L.Convolution2D(256, 256, ksize=3, pad=1, nobias=True)
            self.bn3 = L.BatchNormalization(256)
            self.conv4 = L.Convolution2D(256, n_class, ksize=1)

    def forward(self, x, low_level_features):
        low_level_features = F.relu(self.bn1(self.conv1(low_level_features)))
        # upsample the ASPP output to the low-level feature resolution and fuse
        h = F.resize_images(x, low_level_features.shape[2:])
        h = F.concat((h, low_level_features))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.dropout(h, ratio=0.5)
        h = F.relu(self.bn3(self.conv3(h)))
        h = F.dropout(h, ratio=0.1)
        h = self.conv4(h)
        return h
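Continuing the same toy shapes, the decoder fuses the ASPP output with the low-level features and produces per-class logits at stride 4 (illustrative values, assuming 21 classes):

import numpy as np

aspp_out = np.zeros((1, 256, 7, 7), dtype=np.float32)  # dummy ASPP output
low = np.zeros((1, 24, 56, 56), dtype=np.float32)       # dummy low-level features (mobilenet)
decoder = Decoder(n_class=21, backbone='mobilenet')
with chainer.using_config('train', False):
    logits = decoder(aspp_out, low)
print(logits.shape)  # (1, 21, 56, 56)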
class DeepLab(chainer.Chain):
    def __init__(self, n_class, backbone='xception', output_stride=16):
        super(DeepLab, self).__init__()
        with self.init_scope():
            self.backbone = backbone_module(backbone, output_stride)
            self.aspp = ASPP(backbone, output_stride)
            self.decoder = Decoder(n_class, backbone)

    def forward(self, x):
        h, low_level_features = self.backbone(x)
        h = self.aspp(h)
        h = self.decoder(h, low_level_features)
        # upsample the logits back to the input resolution
        if h.shape[2:] != x.shape[2:]:
            h = F.resize_images(h, x.shape[2:])
        return h
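Putting it together, an end-to-end forward pass. The default backbone is xception, but its code lives in the companion article, so this sketch uses mobilenet; shapes are illustrative:

import numpy as np

net = DeepLab(n_class=21, backbone='mobilenet', output_stride=16)
x = np.zeros((1, 3, 224, 224), dtype=np.float32)  # dummy image
with chainer.using_config('train', False):
    y = net(x)
print(y.shape)  # (1, 21, 224, 224) -- per-pixel class logits at the input resolution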
# Inside the training framework from the companion article
# (ModifiedClassifier, self.classes_names and self.gpu_devices are defined there):
model = DeepLab(n_class=len(self.classes_names), backbone=backbone)
self.model = ModifiedClassifier(model, lossfun=F.softmax_cross_entropy)
if self.gpu_devices >= 0:
    chainer.cuda.get_device_from_id(self.gpu_devices).use()
    self.model.to_gpu()