本文主要基于 LinkNet 的网络结构搭建。
配合文章:语义分割框架链接
def parse_dict(dic, key, value=None):
    """Look up ``key`` in ``dic``, falling back to ``value`` when absent.

    Args:
        dic: Container supporting ``in`` and item access.
        key: Key to look up.
        value: Fallback returned when ``key`` is not in ``dic``.

    Returns:
        ``dic[key]`` when the key is present (even if the stored value is
        falsy or ``None``), otherwise ``value``.
    """
    if key in dic:
        return dic[key]
    return value
class Conv(chainer.Chain):
    """Single convolution link whose concrete layer is chosen by flags.

    Exactly one child link, ``conv``, is registered:

    * ``L.Deconvolution2D`` when ``upsample`` is truthy (``dilation`` is
      not used on this path; ``outsize`` is forwarded),
    * ``L.DilatedConvolution2D`` when ``dilation > 1``,
    * ``L.Convolution2D`` otherwise.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        ksize: Kernel size forwarded to the Chainer link.
        stride: Stride forwarded to the Chainer link.
        pad: Padding forwarded to the Chainer link.
        dilation: Dilation rate; only used when greater than 1.
        nobias: Forwarded as ``nobias`` to the Chainer link.
        upsample: Selects the deconvolution layer when truthy.
        outsize: Forwarded to ``L.Deconvolution2D`` only.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=False, outsize=None):
        super(Conv, self).__init__()
        with self.init_scope():
            if upsample:
                self.conv = L.Deconvolution2D(
                    in_ch, out_ch, ksize, stride, pad,
                    nobias=nobias, outsize=outsize)
            elif dilation > 1:
                self.conv = L.DilatedConvolution2D(
                    in_ch, out_ch, ksize, stride, pad, dilation,
                    nobias=nobias)
            else:
                self.conv = L.Convolution2D(
                    in_ch, out_ch, ksize, stride, pad, nobias=nobias)

    def __call__(self, x):
        """Apply the wrapped convolution to ``x``."""
        out = self.conv(x)
        return out

    def predict(self, x):
        """Apply the wrapped convolution to ``x`` (same as ``__call__``)."""
        out = self.conv(x)
        return out
class ConvBN(Conv):
    """``Conv`` followed by batch normalization.

    Registers an extra child link ``bn`` (``L.BatchNormalization`` over
    ``out_ch`` channels, ``eps=1e-5``, ``decay=0.95``) on top of the
    ``conv`` link created by ``Conv``.

    Args:
        Same as ``Conv``; all are forwarded unchanged to it.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=False, outsize=None):
        super(ConvBN, self).__init__(
            in_ch, out_ch, ksize, stride, pad, dilation, nobias, upsample,
            outsize)
        # Register through init_scope() like the rest of the file;
        # Chain.add_link() is deprecated in Chainer.
        with self.init_scope():
            self.bn = L.BatchNormalization(out_ch, eps=1e-5, decay=0.95)

    def __call__(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

    def predict(self, x):
        """Return ``bn(conv(x))`` (same computation as ``__call__``)."""
        return self.bn(self.conv(x))
class ConvReLU(Conv):
    """``Conv`` followed by a ReLU activation.

    Args:
        Same as ``Conv``; all are forwarded unchanged to it.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=False, outsize=None):
        super(ConvReLU, self).__init__(
            in_ch, out_ch, ksize,
            stride=stride, pad=pad, dilation=dilation, nobias=nobias,
            upsample=upsample, outsize=outsize)

    def __call__(self, x):
        """Return ``relu(conv(x))``."""
        feat = self.conv(x)
        return F.relu(feat)

    def predict(self, x):
        """Return ``relu(conv(x))`` (same computation as ``__call__``)."""
        feat = self.conv(x)
        return F.relu(feat)
class ConvBNReLU(ConvBN):
    """``ConvBN`` followed by a ReLU activation.

    Args:
        Same as ``ConvBN``; all are forwarded unchanged to it.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=False, outsize=None):
        super(ConvBNReLU, self).__init__(
            in_ch, out_ch, ksize,
            stride=stride, pad=pad, dilation=dilation, nobias=nobias,
            upsample=upsample, outsize=outsize)

    def __call__(self, x):
        """Return ``relu(bn(conv(x)))``."""
        normed = self.bn(self.conv(x))
        return F.relu(normed)

    def predict(self, x):
        """Return ``relu(bn(conv(x)))`` (same computation as ``__call__``)."""
        normed = self.bn(self.conv(x))
        return F.relu(normed)
class ConvPReLU(Conv):
    """``Conv`` followed by a learnable PReLU activation.

    Registers an extra child link ``prelu`` (``L.PReLU()`` with default
    arguments) on top of the ``conv`` link created by ``Conv``.

    Args:
        Same as ``Conv``; all are forwarded unchanged to it.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=False, outsize=None):
        super(ConvPReLU, self).__init__(
            in_ch, out_ch, ksize, stride, pad, dilation, nobias, upsample,
            outsize)
        # Register through init_scope() like the rest of the file;
        # Chain.add_link() is deprecated in Chainer.
        with self.init_scope():
            self.prelu = L.PReLU()

    def __call__(self, x):
        """Return ``prelu(conv(x))``."""
        return self.prelu(self.conv(x))

    def predict(self, x):
        """Return ``prelu(conv(x))`` (same computation as ``__call__``)."""
        return self.prelu(self.conv(x))
class ConvBNPReLU(ConvBN):
    """``ConvBN`` followed by a learnable PReLU activation.

    Registers an extra child link ``prelu`` (``L.PReLU()`` with default
    arguments) in addition to the links created by ``ConvBN``.

    Args:
        Same as ``ConvBN``; all are forwarded unchanged to it.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=False, outsize=None):
        super(ConvBNPReLU, self).__init__(
            in_ch, out_ch, ksize, stride, pad, dilation, nobias, upsample,
            outsize)
        # Register through init_scope() like the rest of the file;
        # Chain.add_link() is deprecated in Chainer.
        with self.init_scope():
            self.prelu = L.PReLU()

    def __call__(self, x):
        """Return ``prelu(bn(conv(x)))``."""
        return self.prelu(self.bn(self.conv(x)))

    def predict(self, x):
        """Return ``prelu(bn(conv(x)))`` (same computation as ``__call__``)."""
        return self.prelu(self.bn(self.conv(x)))
class SymmetricConvPReLU(chainer.Chain):
    """Two stacked 1-D convolutions, ``(ksize, 1)`` then ``(1, ksize)``, + PReLU.

    Args:
        in_ch: Number of input channels of the first convolution.
        out_ch: Number of output channels of both convolutions.
        ksize: Length of the 1-D kernels.
        stride: Stride forwarded to both convolutions.
        pad: Padding forwarded to both convolutions.
        dilation: Accepted for signature parity with ``Conv``; not used.
        nobias: Forwarded as ``nobias`` to both convolutions.
        upsample: Accepted for signature parity with ``Conv``; not used.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=None):
        super(SymmetricConvPReLU, self).__init__()
        with self.init_scope():
            # NOTE(review): ``stride`` and ``pad`` are applied on both axes
            # by both convolutions, so with a scalar ``pad`` the axis whose
            # kernel extent is 1 is padded as well -- confirm this is the
            # intended output size.
            self.conv1 = L.Convolution2D(
                in_ch, out_ch, (ksize, 1), stride, pad, nobias=nobias)
            # conv2 consumes conv1's output, which has ``out_ch`` channels;
            # declaring ``in_ch`` here failed whenever in_ch != out_ch.
            self.conv2 = L.Convolution2D(
                out_ch, out_ch, (1, ksize), stride, pad, nobias=nobias)
            self.prelu = L.PReLU()

    def __call__(self, x):
        """Return ``prelu(conv2(conv1(x)))``."""
        return self.prelu(self.conv2(self.conv1(x)))

    def predict(self, x):
        """Return ``prelu(conv2(conv1(x)))`` (same computation as ``__call__``)."""
        return self.prelu(self.conv2(self.conv1(x)))
class SymmetricConvBNPReLU(SymmetricConvPReLU):
    """``SymmetricConvPReLU`` with batch normalization before the PReLU.

    Registers an extra child link ``bn`` (``L.BatchNormalization`` over
    ``out_ch`` channels, ``eps=1e-5``, ``decay=0.95``).

    Args:
        Same as ``SymmetricConvPReLU``; all are forwarded unchanged to it.
    """

    def __init__(self, in_ch, out_ch, ksize, stride=1, pad=1, dilation=1,
                 nobias=False, upsample=None):
        super(SymmetricConvBNPReLU, self).__init__(
            in_ch, out_ch, ksize, stride, pad, dilation, nobias, upsample)
        # Register through init_scope() like the rest of the file;
        # Chain.add_link() is deprecated in Chainer.
        with self.init_scope():
            self.bn = L.BatchNormalization(out_ch, eps=1e-5, decay=0.95)

    def __call__(self, x):
        """Return ``prelu(bn(conv2(conv1(x))))``."""
        h = self.conv2(self.conv1(x))
        return self.prelu(self.bn(h))

    def predict(self, x):
        """Return ``prelu(bn(conv2(conv1(x))))`` (same as ``__call__``)."""
        h = self.conv2(self.conv1(x))
        return self.prelu(self.bn(h))
class InitialBlock(chainer.Chain):
    """Stem of the network: one convolution block followed by max pooling.

    The convolution block class is looked up by name in this module:
    ``Conv`` or ``ConvBN`` (per ``use_bn``) suffixed with ``ReLU`` or
    ``PReLU`` (per ``use_prelu``).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of the convolution block.
        ksize: Kernel size of the convolution.
        stride: Stride of the convolution.
        pad: Padding of the convolution.
        nobias: Forwarded as ``nobias`` to the convolution block.
        psize: Window size of the max pooling.
        use_bn: Use the batch-normalized block variant when truthy.
        use_prelu: Use PReLU instead of ReLU when truthy.
    """

    def __init__(self, in_ch=3, out_ch=13, ksize=3, stride=2, pad=1,
                 nobias=False, psize=3, use_bn=True, use_prelu=False):
        super(InitialBlock, self).__init__()
        block_name = "Conv{}{}".format(
            "BN" if use_bn else "",
            "PReLU" if use_prelu else "ReLU")
        block_cls = getattr(sys.modules[__name__], block_name)
        with self.init_scope():
            self.conv = block_cls(
                in_ch, out_ch, ksize, stride, pad, nobias=nobias)
        self.psize = psize
        # Padding used by predict() only.
        self.ppad = int((psize - 1) / 2)

    def __call__(self, x):
        """Convolve ``x`` and max-pool with stride 2 and no padding."""
        feat = self.conv(x)
        return F.max_pooling_2d(feat, self.psize, 2)

    def predict(self, x):
        """Convolve ``x`` and max-pool with stride 2 and padding ``ppad``.

        NOTE(review): unlike ``__call__`` this path pads the pooling input,
        so the two methods can return different spatial sizes -- confirm
        which one is intended.
        """
        feat = self.conv(x)
        return F.max_pooling_2d(feat, self.psize, 2, self.ppad)
class ResBacisBlock(chainer.Chain):
    """Two-convolution residual block with an optional strided shortcut.

    The main branch is a 3x3 conv block with ReLU followed by a 3x3 conv
    block without activation, both created with ``nobias=True``. When
    ``downsample`` is truthy the first convolution uses stride 2 and the
    shortcut becomes a 1x1, stride-2 conv block (``conv3``); otherwise the
    input itself is the shortcut.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        downsample: Halve the resolution and project the shortcut.
        use_bn: Use the batch-normalized block variants when truthy.
    """

    def __init__(self, in_ch=3, out_ch=13, downsample=False, use_bn=True):
        super(ResBacisBlock, self).__init__()
        self.downsample = downsample
        module = sys.modules[__name__]
        base_name = "ConvBN" if use_bn else "Conv"
        activated_cls = getattr(module, base_name + "ReLU")
        plain_cls = getattr(module, base_name)
        first_stride = 2 if downsample else 1
        with self.init_scope():
            self.conv1 = activated_cls(
                in_ch, out_ch, 3, first_stride, 1, nobias=True)
            self.conv2 = plain_cls(out_ch, out_ch, 3, 1, 1, nobias=True)
            if downsample:
                self.conv3 = plain_cls(in_ch, out_ch, 1, 2, 0, nobias=True)

    def _forward_residual(self, x):
        """Compute ``relu(main_branch(x) + shortcut(x))``."""
        main = self.conv2(self.conv1(x))
        shortcut = self.conv3(x) if self.downsample else x
        return F.relu(main + shortcut)

    def __call__(self, x):
        """Run the residual block on ``x``."""
        return self._forward_residual(x)

    def predict(self, x):
        """Run the residual block on ``x`` (same computation as ``__call__``)."""
        return self._forward_residual(x)
class ResBlock18(chainer.Chain):
    """Four residual stages with two ``ResBacisBlock`` links each.

    Stage widths are 64, 128, 256 and 512 channels, each floor-divided by
    ``alpha``. The first block of stages 2-4 is created with
    ``downsample=True``.

    Args:
        use_bn: Forwarded to every ``ResBacisBlock``.
        train: Value forced onto ``chainer.config.train`` while
            ``__call__`` runs.
        alpha: Integer divisor applied to all channel widths.
    """

    def __init__(self, use_bn=True, train=True, alpha=1):
        super(ResBlock18, self).__init__()
        self.alpha = alpha
        self.train = train
        c1 = 64 // self.alpha
        c2 = 128 // self.alpha
        c3 = 256 // self.alpha
        c4 = 512 // self.alpha
        with self.init_scope():
            self.block1_1 = ResBacisBlock(c1, c1, use_bn=use_bn)
            self.block1_2 = ResBacisBlock(c1, c1, use_bn=use_bn)
            self.block2_1 = ResBacisBlock(
                c1, c2, use_bn=use_bn, downsample=True)
            self.block2_2 = ResBacisBlock(c2, c2, use_bn=use_bn)
            self.block3_1 = ResBacisBlock(
                c2, c3, use_bn=use_bn, downsample=True)
            self.block3_2 = ResBacisBlock(c3, c3, use_bn=use_bn)
            self.block4_1 = ResBacisBlock(
                c3, c4, use_bn=use_bn, downsample=True)
            self.block4_2 = ResBacisBlock(c4, c4, use_bn=use_bn)

    @staticmethod
    def pytorch2chainer(path):
        """Placeholder; currently does nothing and returns ``None``.

        Declared as a static method because the original signature takes
        no ``self``; without the decorator ``instance.pytorch2chainer(path)``
        raised ``TypeError``.
        """
        pass

    def __call__(self, x):
        """Run all stages and return the output of each one.

        The forward pass runs with ``chainer.config.train`` set to
        ``self.train``, regardless of the surrounding configuration.

        Returns:
            Tuple ``(h1, h2, h3, h4)`` with the outputs of stages 1-4.
        """
        with chainer.using_config('train', self.train):
            h1 = self.block1_2(self.block1_1(x))
            h2 = self.block2_2(self.block2_1(h1))
            h3 = self.block3_2(self.block3_1(h2))
            h4 = self.block4_2(self.block4_1(h3))
        return h1, h2, h3, h4

    def predict(self, x):
        """Run all stages and return only the last stage's output.

        Unlike ``__call__`` this does not override ``chainer.config.train``.
        """
        x = self.block1_2(self.block1_1(x))
        x = self.block2_2(self.block2_1(x))
        x = self.block3_2(self.block3_1(x))
        return self.block4_2(self.block4_1(x))
class ResBlock34(chainer.Chain):
    """Four residual stages with 3, 4, 6 and 3 ``ResBacisBlock`` links.

    Stage widths are 64, 128, 256 and 512 channels, each floor-divided by
    ``alpha``. The first block of stages 2-4 is created with
    ``downsample=True``. Child links are named ``block<stage>_<index>``
    with 1-based indices.

    Args:
        use_bn: Forwarded to every ``ResBacisBlock``.
        train: Value forced onto ``chainer.config.train`` while
            ``__call__`` runs.
        alpha: Integer divisor applied to all channel widths.
    """

    # Number of residual blocks registered for stages 1-4.
    _STAGE_DEPTHS = (3, 4, 6, 3)
    # Un-scaled channel width of stages 1-4.
    _STAGE_WIDTHS = (64, 128, 256, 512)

    def __init__(self, use_bn=True, train=True, alpha=1):
        super(ResBlock34, self).__init__()
        self.alpha = alpha
        self.train = train
        with self.init_scope():
            in_ch = self._STAGE_WIDTHS[0] // self.alpha
            stages = zip(self._STAGE_WIDTHS, self._STAGE_DEPTHS)
            for stage, (width, depth) in enumerate(stages, start=1):
                out_ch = width // self.alpha
                for index in range(1, depth + 1):
                    is_first = index == 1
                    block = ResBacisBlock(
                        in_ch if is_first else out_ch,
                        out_ch,
                        use_bn=use_bn,
                        # Only the entry block of stages 2-4 downsamples.
                        downsample=is_first and stage > 1)
                    setattr(self, "block{}_{}".format(stage, index), block)
                in_ch = out_ch

    @staticmethod
    def pytorch2chainer(path):
        """Placeholder; currently does nothing and returns ``None``.

        Declared as a static method because the original signature takes
        no ``self``; without the decorator ``instance.pytorch2chainer(path)``
        raised ``TypeError``.
        """
        pass

    def _run_stage(self, x, stage):
        """Apply every registered block of the 1-based ``stage`` in order."""
        for index in range(1, self._STAGE_DEPTHS[stage - 1] + 1):
            x = getattr(self, "block{}_{}".format(stage, index))(x)
        return x

    def __call__(self, x):
        """Run all stages and return the output of each one.

        Every registered block is executed. The previous forward pass only
        ran ``block*_1`` and ``block*_2``, leaving the remaining blocks
        unused, so weights trained with that version will produce different
        outputs here.

        The forward pass runs with ``chainer.config.train`` set to
        ``self.train``, regardless of the surrounding configuration.

        Returns:
            Tuple ``(h1, h2, h3, h4)`` with the outputs of stages 1-4.
        """
        with chainer.using_config('train', self.train):
            h1 = self._run_stage(x, 1)
            h2 = self._run_stage(h1, 2)
            h3 = self._run_stage(h2, 3)
            h4 = self._run_stage(h3, 4)
        return h1, h2, h3, h4

    def predict(self, x):
        """Run all stages and return only the last stage's output.

        Unlike ``__call__`` this does not override ``chainer.config.train``.
        """
        for stage in range(1, len(self._STAGE_DEPTHS) + 1):
            x = self._run_stage(x, stage)
        return x
class DecoderBlock(chainer.Chain):
    """Bottleneck decoder: 1x1 reduce, ``ksize`` (de)conv, 1x1 expand.

    Block classes are looked up by name in this module: ``conv1`` and
    ``conv2`` use ``Conv``/``ConvBN`` (per ``use_bn``) suffixed with
    ``ReLU``/``PReLU`` (per ``use_prelu``); ``conv3`` uses the variant
    without activation, and a ReLU is applied after the optional residual
    addition.

    Args:
        in_ch: Number of input channels.
        mid_ch: Ignored; the bottleneck width is always ``int(in_ch / 4)``.
        out_ch: Number of output channels.
        ksize: Kernel size of the middle block.
        stride: Stride of the middle block.
        pad: Padding of the middle block.
        residual: Add the block input to the output before the final ReLU.
        nobias: Ignored; the three blocks use fixed ``nobias`` values.
        outsize: Ignored; the middle block always receives ``outsize=None``.
        upsample: Forwarded to the middle block.
        use_bn: Use the batch-normalized block variants when truthy.
        use_prelu: Use PReLU instead of ReLU in ``conv1``/``conv2``.
    """

    def __init__(self, in_ch=3, mid_ch=0, out_ch=13, ksize=3, stride=1,
                 pad=1, residual=False, nobias=False, outsize=None,
                 upsample=False, use_bn=True, use_prelu=False):
        super(DecoderBlock, self).__init__()
        self.residual = residual
        # The caller-supplied mid_ch is overridden on purpose-or-not; the
        # bottleneck is always a quarter of the input width.
        mid_ch = int(in_ch / 4)
        module = sys.modules[__name__]
        base_name = "ConvBN" if use_bn else "Conv"
        act_name = "PReLU" if use_prelu else "ReLU"
        activated_cls = getattr(module, base_name + act_name)
        plain_cls = getattr(module, base_name)
        with self.init_scope():
            self.conv1 = activated_cls(in_ch, mid_ch, 1, 1, 0, nobias=True)
            self.conv2 = activated_cls(
                mid_ch, mid_ch, ksize, stride, pad,
                nobias=False, upsample=upsample, outsize=None)
            self.conv3 = plain_cls(mid_ch, out_ch, 1, 1, 0, nobias=True)

    def _decode(self, x):
        """Run the three blocks, the optional residual add, and the ReLU."""
        out = self.conv3(self.conv2(self.conv1(x)))
        if self.residual:
            out = out + x
        return F.relu(out)

    def __call__(self, x):
        """Decode ``x``."""
        return self._decode(x)

    def predict(self, x):
        """Decode ``x`` (same computation as ``__call__``)."""
        return self._decode(x)
class FullConv(chainer.Chain):
    """Thin wrapper around a single ``L.Deconvolution2D`` link.

    Args:
        in_ch: Number of input channels.
        mid_ch: Accepted but not used.
        out_ch: Number of output channels.
        ksize: Kernel size of the deconvolution.
        stride: Stride of the deconvolution.
        pad: Padding of the deconvolution.
    """

    def __init__(self, in_ch=3, mid_ch=0, out_ch=13, ksize=3, stride=1,
                 pad=1):
        super(FullConv, self).__init__()
        with self.init_scope():
            self.deconv = L.Deconvolution2D(
                in_ch, out_ch, ksize, stride, pad)

    def __call__(self, x):
        """Apply the deconvolution to ``x``."""
        out = self.deconv(x)
        return out

    def predict(self, x):
        """Apply the deconvolution to ``x`` (same as ``__call__``)."""
        out = self.deconv(x)
        return out
class LinkNetBasic(chainer.Chain):
    """LinkNet-style encoder/decoder segmentation network.

    Pipeline: ``InitialBlock`` -> residual encoder (``ResBlock18`` when
    ``n_layers == 18``, otherwise ``ResBlock34``) -> four ``DecoderBlock``
    links with additive skip connections from the encoder stages -> two
    ``ConvBNReLU`` blocks and a final ``FullConv`` producing ``n_class``
    channels.

    Args:
        n_layers: ``18`` selects ``ResBlock18``; any other value selects
            ``ResBlock34``.
        n_class: Number of output channels of the final deconvolution.
        image_size: Side length used to derive the ``outsize`` hints passed
            to the decoder blocks.
        alpha: Integer divisor applied to every channel width.
    """

    def __init__(self, n_layers=18, n_class=None, image_size=512, alpha=1):
        super(LinkNetBasic, self).__init__()
        self.alpha = alpha
        self.size = (image_size, image_size)

        def width(channels):
            # Channel count scaled down by alpha.
            return channels // self.alpha

        def grid(level):
            # Integer division keeps the hint integral on Python 3 (true
            # division produced floats).
            scale = 2 ** level
            return (self.size[0] // scale, self.size[1] // scale)

        with self.init_scope():
            self.initial_block = InitialBlock(
                in_ch=3, out_ch=width(64), ksize=7, stride=2, pad=3,
                nobias=True, psize=3, use_bn=True, use_prelu=False)
            # NOTE(review): train=True makes the encoder's __call__ force
            # chainer.config.train to True on every forward pass, including
            # evaluation -- confirm this is intended.
            if n_layers == 18:
                self.resblock = ResBlock18(
                    use_bn=True, train=True, alpha=self.alpha)
            else:
                self.resblock = ResBlock34(
                    use_bn=True, train=True, alpha=self.alpha)
            self.decoder4 = DecoderBlock(
                in_ch=width(512), mid_ch=0, out_ch=width(256), ksize=2,
                stride=2, pad=0, residual=False, nobias=False,
                outsize=grid(5), upsample=1, use_bn=True, use_prelu=False)
            self.decoder3 = DecoderBlock(
                in_ch=width(256), mid_ch=0, out_ch=width(128), ksize=2,
                stride=2, pad=0, residual=False, nobias=False,
                outsize=grid(4), upsample=1, use_bn=True, use_prelu=False)
            self.decoder2 = DecoderBlock(
                in_ch=width(128), mid_ch=0, out_ch=width(64), ksize=2,
                stride=2, pad=0, residual=False, nobias=False,
                outsize=grid(3), upsample=1, use_bn=True, use_prelu=False)
            self.decoder1 = DecoderBlock(
                in_ch=width(64), mid_ch=0, out_ch=width(64), ksize=3,
                stride=1, pad=1, residual=False, nobias=False,
                outsize=grid(2), upsample=False, use_bn=True,
                use_prelu=False)
            self.finalblock1 = ConvBNReLU(
                in_ch=width(64), out_ch=width(32), ksize=2, stride=2, pad=0,
                dilation=1, nobias=False, upsample=True, outsize=None)
            self.finalblock2 = ConvBNReLU(
                in_ch=width(32), out_ch=width(32), ksize=3, stride=1, pad=1,
                dilation=1, nobias=False, upsample=False, outsize=None)
            self.finalblock3 = FullConv(
                in_ch=width(32), mid_ch=0, out_ch=n_class, ksize=2,
                stride=2, pad=0)

    def __call__(self, x):
        """Run the full network on ``x`` and return the final feature map.

        Each decoder output is summed with the encoder stage of matching
        depth before entering the next decoder.
        """
        x = self.initial_block(x)
        h1, h2, h3, h4 = self.resblock(x)
        x = self.decoder4(h4)
        x += h3
        x = self.decoder3(x)
        x += h2
        x = self.decoder2(x)
        x += h1
        x = self.decoder1(x)
        x = self.finalblock1(x)
        x = self.finalblock2(x)
        x = self.finalblock3(x)
        return x
# Usage excerpt taken from inside a method of another class (it reads
# self.layers, self.classes_names, self.image_size, self.alpha and
# self.gpu_devices, none of which are defined in this excerpt).
# NOTE(review): indentation was lost in this excerpt; presumably the two
# lines after the `if` form its body -- confirm against the original.
# Build the network from the owner's settings; the class count is the number
# of entries in self.classes_names.
model = LinkNetBasic(n_layers=self.layers,n_class = len(self.classes_names),image_size=self.image_size,alpha=self.alpha)
# Wrap the network in PixelwiseSoftmaxClassifier (defined outside this excerpt).
self.model = PixelwiseSoftmaxClassifier(model)
# A non-negative gpu_devices is used as the CUDA device id to select.
if self.gpu_devices>=0:
chainer.cuda.get_device_from_id(self.gpu_devices).use()
self.model.to_gpu()