# This file implements OCR network architectures based on MXNet.
class AttModel(nn.HybridBlock):
    """Attention-based sequence recognition model (encoder + attention decoder).

    During training the decoder is teacher-forced with ``targ_input``; at
    inference time each step feeds back the argmax of the previous prediction.

    Parameters
    ----------
    encoder : nn.HybridBlock
        Feature encoder; called as ``encoder(x, mask)`` and must return
        ``(en_out, en_proj, mask)``.
    decoder : nn.HybridBlock
        Attention decoder; must expose a ``lstm`` attribute providing
        ``begin_state`` (see ``begin_state`` below).
    start_symbol : int
        Token id used to build the first decoder input at inference time.
    end_symbol : int
        End-of-sequence token id.
        NOTE(review): stored but never used in this class — confirm whether
        early stopping on the end symbol was intended in ``test_func``.
    """
    def __init__(self, encoder, decoder, start_symbol=0, end_symbol=1, **kwargs):
        super(AttModel, self).__init__(**kwargs)
        self.start_symbol = start_symbol
        self.end_symbol = end_symbol
        with self.name_scope():
            self.encoder = encoder
            self.decoder = decoder

    def hybrid_forward(self, F, x, mask, targ_input):
        # Build the initial RNN state. Only the NDArray path can know the
        # batch size; the symbolic path relies on deferred shape inference.
        if isinstance(x, mx.ndarray.NDArray):
            batch_size = x.shape[0]
            state = self.begin_state(func=mx.ndarray.zeros, ctx=x.context, batch_size=batch_size, dtype=x.dtype)
        else:
            state = self.begin_state(func=mx.symbol.zeros)
        en_out, en_proj, mask = self.encoder(x, mask)
        # Loop state carried through F.contrib.foreach: encoder outputs,
        # projected encoder features, the (possibly resized) mask and the RNN state.
        states = [en_out, en_proj, mask, state]
        # (batch, T) -> (T, batch, 1) so foreach iterates over time steps.
        tag_input = F.transpose(targ_input, axes=(1, 0)).expand_dims(axis=-1)

        def train_func(out, states):
            # Teacher forcing: the ground-truth token for this step is fed in.
            outs = self.decoder(out, states)
            return outs[0], outs[1:]

        if autograd.is_training():
            outputs, states = F.contrib.foreach(train_func, tag_input, states)
            # (T, batch, 1, voc) -> (T, batch*1, voc) -> (batch, T, voc).
            outputs = F.reshape(outputs, shape=(0, -3, 0))
            outputs = F.transpose(outputs, axes=(1, 0, 2))
            return outputs

        def test_func(inp, states):
            # Greedy decoding: previous prediction (stored in states[0]) is
            # consumed by the decoder; `inp` only paces the loop length.
            outs = self.decoder(*states)
            pred = F.softmax(outs[0], axis=-1)
            pred = F.argmax(pred, axis=-1)
            # NOTE(review): `pred * inp` multiplies the predicted id by the
            # current target-input value — presumably `inp` is all-ones here;
            # confirm against the caller.
            states = [pred * inp] + list(outs[1:])
            return pred, states

        # First decoder input: a tensor shaped like one time step, filled
        # with start_symbol (ones * start_symbol).
        first_input = F.slice_axis(tag_input, axis=0, begin=0, end=1)
        first_input = F.squeeze(first_input, axis=0) * self.start_symbol
        states = [first_input] + states
        outputs, states = F.contrib.foreach(test_func, tag_input, states)
        # (T, batch) -> (batch, T) token ids.
        outputs = F.reshape(outputs, (0, -1))
        outputs = F.transpose(outputs, axes=(1, 0))
        return outputs

    def __call__(self, data, mask, targ_input):
        # Redundant override kept for explicitness; simply forwards to HybridBlock.
        return super(AttModel, self).__call__(data, mask, targ_input)

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Hybridize, run one dummy forward pass and export the symbol/params.

        NOTE(review): parameters are initialized randomly instead of loading
        `param_path` (the load call is commented out) — the exported model
        will carry untrained weights; confirm this is intentional.
        """
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 32, 128), ctx=ctx[0])
        mask = mx.nd.ones((1, 1, 1, 16), ctx=ctx[0])
        targ_input = mx.nd.ones((1, 10), ctx=ctx[0])
        self.hybridize()
        # self.load_parameters(param_path)
        self.initialize()
        self.collect_params().reset_ctx(ctx)
        outs = self.__call__(data, mask, targ_input)
        self.export(prefix)

    def begin_state(self, *args, **kwargs):
        # Delegate initial-state creation to the decoder's LSTM.
        return self.decoder.lstm.begin_state(*args, **kwargs)
import mxnet as mx
from mxnet import autograd
from mxnet import gluon
from mxnet.gluon import nn
from nets.anchor import CLRSAnchorGenerator
from nets.backbone.resnet import get_resnet
from nets.backbone.mobilenetv3 import get_mobilenet_v3
from nets.backbone.resnext import get_resnext
class MultiPerClassDecoder(gluon.HybridBlock):
    """Turn class probabilities into per-foreground-class ids and scores.

    Channel 0 (background) is discarded. Entries whose score does not exceed
    ``thresh`` are masked out: their class id becomes -1 and score becomes 0.
    """
    def __init__(self, num_class, axis=-1, thresh=0.01):
        super(MultiPerClassDecoder, self).__init__()
        self._fg_class = num_class - 1
        self._axis = axis
        self._thresh = thresh

    def hybrid_forward(self, F, x):
        # Drop the background channel -> (b, N, fg_class) scores.
        fg_scores = x.slice_axis(axis=self._axis, begin=1, end=None)
        # Broadcast a [0 .. fg_class-1] id row over every (b, N) position.
        zeros = F.zeros_like(x.slice_axis(axis=-1, begin=0, end=1))
        id_row = F.reshape(F.arange(self._fg_class), shape=(1, 1, self._fg_class))
        cls_ids = F.broadcast_add(zeros, id_row)
        keep = fg_scores > self._thresh
        cls_ids = F.where(keep, cls_ids, F.ones_like(cls_ids) * -1)
        fg_scores = F.where(keep, fg_scores, F.zeros_like(fg_scores))
        return cls_ids, fg_scores
class BBoxCornerToCenter(gluon.HybridBlock):
    """Convert corner boxes (xmin, ymin, xmax, ymax) to center format (x, y, w, h).

    Parameters
    ----------
    axis : int
        Axis along which the four coordinates are laid out.
    split : bool
        If True return the four tensors separately; otherwise concatenate them.
    """
    def __init__(self, axis=-1, split=False):
        super(BBoxCornerToCenter, self).__init__()
        self._split = split
        self._axis = axis

    def hybrid_forward(self, F, x):
        left, top, right, bottom = F.split(x, axis=self._axis, num_outputs=4)
        w = right - left
        h = bottom - top
        cx = left + 0.5 * w
        cy = top + 0.5 * h
        if self._split:
            return cx, cy, w, h
        return F.concat(cx, cy, w, h, dim=self._axis)
class NormalizedBoxCenterDecoder(gluon.HybridBlock):
    """Decode normalized center-offset box predictions against anchors.

    Predictions are (dx, dy, dw, dh) scaled by ``stds``; output boxes are in
    corner format (xmin, ymin, xmax, ymax).

    Parameters
    ----------
    stds : tuple of 4 floats
        De-normalization factors for (dx, dy, dw, dh).
    convert_anchor : bool
        If True, anchors arrive in corner format and are converted to center
        format first.
    clip : float or None
        Optional upper bound applied to dw/dh before exponentiation.
    """
    def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), convert_anchor=False, clip=None):
        super(NormalizedBoxCenterDecoder, self).__init__()
        assert len(stds) == 4, "Box Encoder requires 4 std values."
        self._stds = stds
        self._clip = clip
        if convert_anchor:
            self.corner_to_center = BBoxCornerToCenter(split=True)
        else:
            self.corner_to_center = None
        self._format = 'corner' if convert_anchor else 'center'

    def hybrid_forward(self, F, x, anchors):
        # Fast path: fused C++ operator when available.
        if 'box_decode' in F.contrib.__dict__:
            x, anchors = F.amp_multicast(x, anchors, num_outputs=2, cast_narrow=True)
            # Bug fix: the original assigned ``self._clip = -1`` here, mutating
            # layer state inside the forward pass. A later call taking the
            # manual branch would then see a truthy clip of -1 and clamp
            # dw/dh to -1. Use a local value instead.
            clip = -1 if self._clip is None else self._clip
            return F.contrib.box_decode(x, anchors, self._stds[0], self._stds[1],
                                        self._stds[2], self._stds[3],
                                        clip=clip, format=self._format)
        # Manual fallback.
        if self.corner_to_center is not None:
            a = self.corner_to_center(anchors)
        else:
            a = anchors.split(axis=-1, num_outputs=4)
        p = F.split(x, axis=-1, num_outputs=4)
        # Center: prediction * std * anchor size + anchor center.
        ox = F.broadcast_add(F.broadcast_mul(p[0] * self._stds[0], a[2]), a[0])
        oy = F.broadcast_add(F.broadcast_mul(p[1] * self._stds[1], a[3]), a[1])
        dw = p[2] * self._stds[2]
        dh = p[3] * self._stds[3]
        if self._clip:
            dw = F.minimum(dw, self._clip)
            dh = F.minimum(dh, self._clip)
        dw = F.exp(dw)
        dh = F.exp(dh)
        # Half-extent of the decoded box.
        ow = F.broadcast_mul(dw, a[2]) * 0.5
        oh = F.broadcast_mul(dh, a[3]) * 0.5
        return F.concat(ox - ow, oy - oh, ox + ow, oy + oh, dim=-1)
class DM(nn.HybridBlock):
    """Deconvolution module: fuse a coarse feature map with a finer one.

    The coarse input (``x1``) is upsampled by a transposed convolution stack,
    the fine input (``x2``) passes through a small conv stack, and the two are
    merged by element-wise multiplication followed by ReLU.
    """
    def __init__(self, channels=128, ksize=2, strides=2, pad=0, **kwargs):
        super(DM, self).__init__(**kwargs)
        with self.name_scope():
            # Layer creation order matches the original to keep parameter names stable.
            self.deconv = nn.HybridSequential()
            for layer in (nn.Conv2DTranspose(channels, ksize, strides),
                          nn.Conv2D(channels, 3, 1, 1),
                          nn.BatchNorm()):
                self.deconv.add(layer)
            self.conv = nn.HybridSequential()
            for layer in (nn.Conv2D(channels, 3, 1, 1),
                          nn.BatchNorm(),
                          nn.Activation('relu'),
                          nn.Conv2D(channels, 3, 1, 1),
                          nn.BatchNorm()):
                self.conv.add(layer)

    def hybrid_forward(self, F, x1, x2):
        upsampled = self.deconv(x1)
        lateral = self.conv(x2)
        # Multiplicative fusion, then ReLU.
        return F.relu(upsampled * lateral)
class PM(nn.HybridBlock):
    """Prediction module: per-level classification and box-offset heads.

    A 1x1 skip branch and a 3-conv bottleneck branch are summed, ReLU'd, then
    fed to two 3x3 heads: ``conf`` predicting k*(num_classes+1) class scores
    and ``loc`` predicting k*4 box offsets.
    """
    def __init__(self, channels=256, k=4, num_classes=4, **kwargs):
        super(PM, self).__init__(**kwargs)
        with self.name_scope():
            self.skip = nn.HybridSequential()
            self.skip.add(nn.Conv2D(channels, 1, 1))
            self.bone = nn.HybridSequential()
            for _ in range(3):
                self.bone.add(nn.Conv2D(channels, 1, 1))
            self.conf = nn.Conv2D(k * (num_classes + 1), 3, 1, 1)
            self.loc = nn.Conv2D(k * 4, 3, 1, 1)

    def hybrid_forward(self, F, x):
        merged = F.relu(self.skip(x) + self.bone(x))
        return self.conf(merged), self.loc(merged)
class SM(nn.HybridBlock):
    """Segmentation module: residual 1x1-conv block with optional upsampling.

    Skip and bottleneck branches are summed and ReLU'd; the result is
    nearest-neighbor upsampled by ``n_scale`` when ``n_scale > 1``.
    """
    def __init__(self, channels, n_scale=2, **kwargs):
        super(SM, self).__init__(**kwargs)
        self.n_scale = n_scale
        with self.name_scope():
            self.skip = nn.HybridSequential()
            self.skip.add(nn.Conv2D(channels, 1, 1))
            self.skip.add(nn.BatchNorm())
            self.bone = nn.HybridSequential()
            # Two conv-BN-ReLU units followed by a final conv-BN.
            for _ in range(2):
                self.bone.add(nn.Conv2D(channels, 1, 1))
                self.bone.add(nn.BatchNorm())
                self.bone.add(nn.Activation('relu'))
            self.bone.add(nn.Conv2D(channels, 1, 1))
            self.bone.add(nn.BatchNorm())

    def hybrid_forward(self, F, x):
        fused = F.relu(self.skip(x) + self.bone(x))
        if self.n_scale > 1:
            fused = F.UpSampling(fused, scale=self.n_scale, sample_type='nearest')
        return fused
class SegPred(nn.HybridBlock):
    """Segmentation prediction head fusing five pyramid levels.

    Each input level goes through an SM block that brings it to a common
    resolution (scales 16/8/4/2/1); the five maps are summed, ReLU'd and
    decoded by a conv/deconv tail ending in a 4-channel sigmoid map.
    """
    def __init__(self, channels, **kwargs):
        super(SegPred, self).__init__(**kwargs)
        with self.name_scope():
            self.sms = nn.HybridSequential()
            for scale in (16, 8, 4, 2, 1):
                self.sms.add(SM(channels, scale))
            self.tail = nn.HybridSequential()
            for layer in (nn.Conv2D(channels, 1, 1),
                          nn.BatchNorm(),
                          nn.Activation('relu'),
                          nn.Conv2DTranspose(channels, 2, 2),
                          nn.Conv2D(channels, 3, 1, 1),
                          nn.BatchNorm(),
                          nn.Activation('relu'),
                          nn.Conv2DTranspose(4, 2, 2),
                          nn.Activation('sigmoid')):
                self.tail.add(layer)

    def hybrid_forward(self, F, *xs):
        upsampled = [sm(feat) for feat, sm in zip(xs, self.sms)]
        fused = F.relu(F.add_n(*upsampled))
        return self.tail(fused)
class CLRS(nn.HybridBlock):
    """CLRS text detector: SSD-style multi-level detection plus a segmentation branch.

    Backbone ``stages`` and ``extras`` build a feature pyramid; DM modules
    fuse it top-down; PM modules predict per-level class scores and box
    offsets; a SegPred head predicts segmentation maps. At inference the
    predictions are decoded against anchors and filtered with NMS.

    Parameters
    ----------
    stages : list of HybridBlock
        Backbone feature stages.
    sizes, ratios, steps : lists
        Per-level anchor sizes / aspect ratios / strides (7 levels).
    dm_channels, pm_channels, sm_channels : int
        Channel widths of the DM, PM and SegPred modules.
    stds : tuple of 4 floats
        Box-offset de-normalization factors.
    nms_thresh, nms_topk, post_nms : float, int, int
        NMS overlap threshold, pre-NMS top-k and post-NMS keep count.
    anchor_alloc_size : int
        Initial workspace size for anchor generation (halved per level).
    """
    def __init__(self, stages, sizes, ratios, steps, dm_channels=256,
                 pm_channels=256, sm_channels=32,
                 stds=(0.1, 0.1, 0.2, 0.2), nms_thresh=0.45,
                 nms_topk=1000, post_nms=400,
                 anchor_alloc_size=256, **kwargs):
        super(CLRS, self).__init__(**kwargs)
        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        self.post_nms = post_nms
        with self.name_scope():
            self.stages = nn.HybridSequential()
            for i in range(len(stages)):
                self.stages.add(stages[i])
            # extra layers appended after the backbone (two keep resolution
            # with strides=1)
            self.extras = nn.HybridSequential()
            self.extras.add(self._extra_layer(256, 512))
            self.extras.add(self._extra_layer(128, 256))
            self.extras.add(self._extra_layer(128, 256, strides=1))
            self.extras.add(self._extra_layer(128, 256, strides=1))
            # Top-down fusion modules: the first two levels were produced
            # with strides=1 extras, so their DMs do not upsample.
            self.dms = nn.HybridSequential()
            for i in range(6):
                strides = 2 if i > 1 else 1
                ksize = 2 if strides == 2 else 3
                self.dms.add(DM(dm_channels, ksize, strides=strides, pad=ksize - 2))
            # One prediction module + anchor generator per pyramid level.
            self.pms = nn.HybridSequential()
            self.anchor_generators = nn.HybridSequential()
            asz = anchor_alloc_size
            for i, (s, r, st) in enumerate(zip(sizes, ratios, steps)):
                self.pms.add(PM(pm_channels, len(s)))
                anchor_generator = CLRSAnchorGenerator(i, (512, 512), s, r, st, alloc_size=(asz, asz))
                self.anchor_generators.add(anchor_generator)
                # Halve the anchor workspace per level, floor at 16.
                asz = max(asz // 2, 16)
            self.seg_pred = SegPred(sm_channels)
            self.bbox_decoder = NormalizedBoxCenterDecoder(stds)
            # 4 foreground classes + background.
            self.cls_decoder = MultiPerClassDecoder(4 + 1, thresh=0.01)

    def _extra_layer(self, in_channels, out_channels, strides=2, norm_layer=nn.BatchNorm, norm_kwargs=None):
        """1x1 reduce + 3x3 conv unit used to extend the backbone pyramid."""
        layer = nn.HybridSequential()
        layer.add(nn.Conv2D(in_channels, 1, 1))
        layer.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
        layer.add(nn.Activation('relu'))
        layer.add(nn.Conv2D(out_channels, 3, strides, strides - 1))
        layer.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
        layer.add(nn.Activation('relu'))
        return layer

    def set_nms(self, nms_thresh=0.45, nms_topk=400, post_nms=100):
        """Update NMS settings; clears the cached hybridized graph first."""
        self._clear_cached_op()
        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        self.post_nms = post_nms

    def hybrid_forward(self, F, x):
        # Collect backbone + extra feature maps (7 levels total).
        feats = []
        for block in self.stages:
            x = block(x)
            feats.append(x)
        for block in self.extras:
            x = block(x)
            feats.append(x)
        # Top-down fusion, deepest level first.
        F11 = feats[-1]
        F10 = self.dms[0](F11, feats[-2])
        F9 = self.dms[1](F10, feats[-3])
        F8 = self.dms[2](F9, feats[-4])
        F7 = self.dms[3](F8, feats[-5])
        F4 = self.dms[4](F7, feats[-6])
        F3 = self.dms[5](F4, feats[-7])
        feats = [F3, F4, F7, F8, F9, F10, F11]
        # Per-level predictions flattened to (b, N, 5) scores and (b, N, 4) offsets.
        box_preds, cls_preds = [], []
        for feat, block in zip(feats, self.pms):
            pred = block(feat)
            cls_preds.append(F.flatten(F.transpose(pred[0], (0, 2, 3, 1))))
            box_preds.append(F.flatten(F.transpose(pred[1], (0, 2, 3, 1))))
        cls_preds = F.concat(*cls_preds, dim=1).reshape((0, -1, 5))
        box_preds = F.concat(*box_preds, dim=1).reshape((0, -1, 4))
        anchors = [F.reshape(ag(feat), shape=(1, -1))
                   for feat, ag in zip(feats, self.anchor_generators)]
        anchors = F.concat(*anchors, dim=1).reshape((1, -1, 4))
        seg_maps = self.seg_pred(F9, F8, F7, F4, F3)
        if autograd.is_training():
            # Raw outputs for the loss functions.
            return [cls_preds, box_preds, anchors, seg_maps]
        # Inference: decode boxes and class scores, then NMS.
        bboxes = self.bbox_decoder(box_preds, anchors)
        cls_ids, scores = self.cls_decoder(F.softmax(cls_preds, axis=-1))
        results = []
        for i in range(4):
            cls_id = cls_ids.slice_axis(axis=-1, begin=i, end=i + 1)
            score = scores.slice_axis(axis=-1, begin=i, end=i + 1)
            # per class results
            per_result = F.concat(*[cls_id, score, bboxes], dim=-1)
            results.append(per_result)
        result = F.concat(*results, dim=1)
        if self.nms_thresh > 0 and self.nms_thresh < 1:
            result = F.contrib.box_nms(
                result, overlap_thresh=self.nms_thresh, topk=self.nms_topk, valid_thresh=0.01,
                id_index=0, score_index=1, coord_start=2, force_suppress=False)
            if self.post_nms > 0:
                result = result.slice_axis(axis=1, begin=0, end=self.post_nms)
        # Split the (id, score, box) columns back out.
        ids = F.slice_axis(result, axis=2, begin=0, end=1)
        scores = F.slice_axis(result, axis=2, begin=1, end=2)
        bboxes = F.slice_axis(result, axis=2, begin=2, end=6)
        return ids, scores, bboxes, seg_maps

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load weights, run a dummy 512x512 forward pass and export."""
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 512, 512), dtype='float32', ctx=ctx[0])
        self.load_parameters(param_path)
        self.set_nms(0.45, 1000, 400)
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        pred1 = self(data)
        self.export(prefix, epoch=0)
def get_clrs(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None):
    """Build a CLRS detector on top of a named backbone.

    Parameters
    ----------
    backbone_name : str
        One of e.g. 'resnet_v1_50', 'resnext_50', 'mobilenetv3_small'.
    norm_layer, norm_kwargs
        Normalization layer class and its kwargs, forwarded to the backbone.

    Returns
    -------
    CLRS
        The assembled detector (uninitialized).

    Raises
    ------
    ValueError
        If the backbone name is not recognized.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, strides=[(2, 2), (1, 1), (2, 2), (2, 2), (1, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 8]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, strides=[(2, 2), (1, 1), (2, 2), (2, 2), (1, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 8]
    elif 'mobilenetv3' in name:
        # Bug fix: the variant name ('small'/'large') is a string; the original
        # wrapped it in int(), which always raised ValueError. The duplicated
        # ids assignment in the 'large' branch is also folded away.
        model_name = name.split('_')[1]
        if model_name == 'small':
            ids = [4, 6, 17]
            base_net = get_mobilenet_v3('small', strides=[(2, 2), (2, 2), (2, 2), (1, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        else:
            ids = [6, 9, 21]
            base_net = get_mobilenet_v3('large', norm_layer=norm_layer, norm_kwargs=norm_kwargs)
    else:
        raise ValueError('The %s is not support.' % backbone_name)
    # Slice the backbone into three stages at the feature indices above.
    stages = [base_net.features[:ids[0]],
              base_net.features[ids[0]:ids[1]],
              base_net.features[ids[1]:ids[2]],
              ]
    # Anchor sizes / ratios / strides for the 7 pyramid levels.
    # NOTE(review): step 86 breaks the power-of-two progression (4..64, 128)
    # — possibly a typo for 96; kept as-is.
    sizes = [[4, 6, 8, 10, 12, 16], [20, 24, 28, 32], [36, 40, 44, 48],
             [56, 64, 72, 80], [88, 96, 104, 112], [124, 136, 148, 160],
             [184, 208, 232, 256]]
    ratios = [[1.0] for _ in range(7)]
    steps = [4, 8, 16, 32, 64, 86, 128]
    net = CLRS(stages, sizes, ratios, steps)
    return net
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from nets.stn import STN
from nets.rnn_layer import RNNLayer
from nets.backbone.resnet import get_resnet
from nets.backbone.resnext import get_resnext
from nets.backbone.vgg import get_vgg
from nets.backbone.mobilenetv3 import get_mobilenet_v3
class CRNN(nn.HybridBlock):
    """CRNN text recognizer: conv backbone + (Bi)LSTM + per-step classifier.

    Parameters
    ----------
    stages : HybridBlock
        Convolutional backbone producing a (N, C, H, W) feature map.
    hidden_size : int
        LSTM hidden size.
    num_layers : int
        Number of stacked LSTM layers.
    dropout : float
        Dropout applied inside the LSTM and before the classifier.
    voc_size : int
        Output vocabulary size.
    birnn : bool
        Use a bidirectional LSTM.
    use_stn : bool
        Prepend a spatial transformer network for input rectification.
    LSTMFlag : bool
        If True use ``gluon.rnn.LSTM``; otherwise the project's ``RNNLayer``.
    """
    def __init__(self, stages, hidden_size=256, num_layers=2, dropout=0.1, voc_size=37, birnn=True, use_stn=False, LSTMFlag=True, **kwargs):
        super(CRNN, self).__init__(**kwargs)
        self.use_stn = use_stn
        with self.name_scope():
            self.stages = stages
            if use_stn:
                self.stn = STN()
            # Idiom fix: `if LSTMFlag == False:` replaced by a direct boolean
            # test (branch order swapped accordingly; behavior unchanged).
            if LSTMFlag:
                self.lstm = gluon.rnn.LSTM(hidden_size, num_layers, dropout=dropout, bidirectional=birnn, layout='NTC')
            else:
                self.lstm = RNNLayer('lstm', num_layers, hidden_size, dropout=dropout, bidirectional=birnn, layout='NTC')
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Dense(voc_size, flatten=False)

    def hybrid_forward(self, F, x):
        if self.use_stn:
            x = self.stn(x)
        x = self.stages(x)
        # (N, C, H, W) -> (N, W, H, C) -> (N, W, H*C): one feature vector
        # per horizontal position (time step).
        x = F.transpose(x, axes=(0, 3, 2, 1))
        x = F.reshape(x, (0, -3, 0))
        x = self.lstm(x)
        x = self.dropout(x)
        x = self.fc(x)
        return x

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load weights, run a dummy 32x320 forward pass and export."""
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 32, 320), dtype='float32', ctx=ctx[0])
        self.load_parameters(param_path)
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        pred1 = self(data)
        self.export(prefix, epoch=0)
def get_crnn(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
    """Build a CRNN recognizer on top of a named backbone.

    Parameters
    ----------
    backbone_name : str
        One of e.g. 'resnet_v1_34', 'resnext_50', 'vgg_16', 'mobilenetv3_small'.
    norm_layer, norm_kwargs
        Normalization layer class and its kwargs, forwarded to the backbone.
    **kwargs
        Forwarded to the CRNN constructor.

    Raises
    ------
    ValueError
        If the backbone name is not recognized.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, strides=[(2, 1), (1, 1), (2, 2), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-1]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, strides=[(2, 1), (1, 1), (2, 2), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-1]
    elif 'vgg' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_vgg(num_layers, strides=[(2, 1), (2, 2), (2, 2), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-4]
    elif 'mobilenetv3' in name:
        # Bug fix: the variant name ('small'/'large') is a string; the original
        # wrapped it in int(), which always raised ValueError.
        model_name = name.split('_')[1]
        base_net = get_mobilenet_v3(model_name, strides=[(2, 2), (2, 1), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-4]
    else:
        raise ValueError('Please input right backbone name.')
    crnn = CRNN(backbone, **kwargs)
    return crnn
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from nets.backbone.resnet import get_resnet
from nets.backbone.mobilenetv3 import get_mobilenet_v3
from nets.backbone.resnext import get_resnext
class DBNet(gluon.HybridBlock):
    """Differentiable Binarization (DB) text detector.

    An FPN-style neck fuses four backbone stages; a ``binarize`` head predicts
    the probability map and, when ``adaptive`` is True, a ``thresh`` head
    predicts a threshold map which is combined into a differentiable binary
    map ``1 / (1 + exp(-k * (binary - thresh)))``.

    Parameters
    ----------
    stages : list of HybridBlock
        Four backbone feature stages.
    inner_channels : int
        Channel width of the FPN lateral projections.
    k : float
        Steepness of the differentiable binarization.
    use_bias : bool
        Whether the neck/head convolutions use bias terms.
    adaptive : bool
        Enable the threshold branch (training-time DB formulation).
    """
    def __init__(self, stages, inner_channels=256, k=10, use_bias=False, adaptive=True, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
        super(DBNet, self).__init__(**kwargs)
        self.k = k
        self.adaptive = adaptive
        with self.name_scope():
            self.stages = nn.HybridSequential()
            self.ins_proj = nn.HybridSequential()
            self.outs = nn.HybridSequential()
            # One lateral 1x1 projection and one 3x3 output conv per stage.
            for i in range(len(stages)):
                self.stages.add(stages[i])
                self.ins_proj.add(nn.Conv2D(inner_channels, 1, use_bias=use_bias))
                self.outs.add(nn.Conv2D(inner_channels // 4, 3, padding=1, use_bias=use_bias))
            # Probability-map head: conv-BN-ReLU, 2x deconv, final 1-ch deconv + sigmoid.
            self.binarize = nn.HybridSequential()
            self.binarize.add(nn.Conv2D(inner_channels // 4, 3, padding=1, use_bias=use_bias))
            self.binarize.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
            self.binarize.add(nn.Activation('relu'))
            self.binarize.add(nn.Conv2DTranspose(inner_channels // 4, 2, 2))
            self.binarize.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
            self.binarize.add(nn.Activation('relu'))
            self.binarize.add(nn.Conv2DTranspose(1, 2, 2))
            self.binarize.add(nn.Activation('sigmoid'))
            if adaptive:
                # Threshold-map head, same layout as the binarize head.
                self.thresh = nn.HybridSequential()
                self.thresh.add(nn.Conv2D(inner_channels // 4, 3, padding=1, use_bias=use_bias))
                self.thresh.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
                self.thresh.add(nn.Activation('relu'))
                self.thresh.add(nn.Conv2DTranspose(inner_channels // 4, 2, 2))
                self.thresh.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
                self.thresh.add(nn.Activation('relu'))
                self.thresh.add(nn.Conv2DTranspose(1, 2, 2))
                self.thresh.add(nn.Activation('sigmoid'))

    def hybrid_forward(self, F, x):
        # Collect the four backbone feature maps.
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        # Lateral 1x1 projections to a common channel width.
        for i, block in enumerate(self.ins_proj):
            features[i] = block(features[i])
        in2, in3, in4, in5 = features
        # Top-down pathway: upsample (bilinear, to the lateral's size) and add.
        out4 = F.contrib.BilinearResize2D(in5, like=in4, mode='like') + in4
        out3 = F.contrib.BilinearResize2D(out4, like=in3, mode='like') + in3
        out2 = F.contrib.BilinearResize2D(out3, like=in2, mode='like') + in2
        features = [out2, out3, out4, in5]
        # Bring every level to in2's resolution and concatenate.
        output = []
        for feat, block in zip(features, self.outs):
            out = block(feat)
            out = F.contrib.BilinearResize2D(out, like=in2, mode='like')
            output.append(out)
        fuse = F.concat(*output, dim=1)
        binary = self.binarize(fuse)
        # if not mx.autograd.is_training():
        # return binary
        if self.adaptive:
            # Condition the threshold head on the fused features plus the
            # (downsampled) probability map.
            temp = F.contrib.BilinearResize2D(binary, like=fuse, mode='like')
            fuse = F.concat(fuse, temp, dim=1)
            thresh = self.thresh(fuse)
            # Differentiable binarization.
            thresh_binary = 1.0 / (1.0 + F.exp(-self.k * (binary - thresh)))
            return binary, thresh, thresh_binary
        else:
            return binary

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load weights, run a dummy 512x512 forward pass and export."""
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 512, 512), dtype='float32', ctx=ctx[0])
        self.load_parameters(param_path)
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        pred1 = self(data)
        self.export(prefix, epoch=0)
def get_db(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
    """Build a DBNet detector on top of a named backbone.

    Parameters
    ----------
    backbone_name : str
        One of e.g. 'resnet_v1_50', 'resnext_50', 'mobilenetv3_small'.
    norm_layer, norm_kwargs
        Normalization layer class and its kwargs, forwarded to the backbone.
    **kwargs
        Forwarded to the DBNet constructor.

    Raises
    ------
    ValueError
        If the backbone name is not recognized.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 7, 8]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 7, 8]
    elif 'mobilenetv3' in name:
        # Bug fix: the variant name ('small'/'large') is a string; the original
        # wrapped it in int(), which always raised ValueError.
        model_name = name.split('_')[1]
        ids = [4, 6, 12, 17] if model_name == 'small' else [6, 9, 15, 21]
        base_net = get_mobilenet_v3(model_name, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
    else:
        raise ValueError('The %s is not support.' % backbone_name)
    # Slice the backbone into four stages at the feature indices above.
    stages = [base_net.features[:ids[0]],
              base_net.features[ids[0]:ids[1]],
              base_net.features[ids[1]:ids[2]],
              base_net.features[ids[2]:ids[3]],
              ]
    net = DBNet(stages, **kwargs)
    return net
from mxnet.gluon import nn
import mxnet as mx
from nets.backbone.resnet import get_resnet
from nets.backbone.resnext import get_resnext
from nets.backbone.mobilenetv3 import get_mobilenet_v3
class EAST(nn.HybridBlock):
    """EAST text detector: U-shaped feature merging plus score/geometry heads.

    Four backbone stages are merged bottom-up (upsample, concat, conv); a
    sigmoid head predicts the text score map and an 8-channel head predicts
    QUAD geometry offsets rescaled to [-800, 800].

    Parameters
    ----------
    stages : list of HybridBlock
        Four backbone feature stages.
    channels : list of int
        Widths of the three merge convs and (last entry) the head convs.
    norm_layer, norm_kwargs
        Normalization layer class and its kwargs.
    """
    def __init__(self, stages, channels=[128, 128, 128, 32], norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
        super(EAST, self).__init__(**kwargs)
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        with self.name_scope():
            self.stages = nn.HybridSequential()
            self.convs = nn.HybridSequential()
            for i in range(len(stages)):
                self.stages.add(stages[i])
            # Three merge units for the three top-down fusion steps.
            for i in range(3):
                self.convs.add(self._make_layers(channels[i]))
            # Text/non-text score head.
            self.pred_score = nn.HybridSequential()
            self.pred_score.add(nn.Conv2D(channels[-1], 3, 1, 1))
            self.pred_score.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
            self.pred_score.add(nn.Activation('relu'))
            self.pred_score.add(nn.Conv2D(1, 1, 1))
            self.pred_score.add(nn.Activation('sigmoid'))
            # 8-channel QUAD geometry head.
            self.pred_geo = nn.HybridSequential()
            self.pred_geo.add(nn.Conv2D(channels[-1], 3, 1, 1))
            self.pred_geo.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
            self.pred_geo.add(nn.Activation('relu'))
            self.pred_geo.add(nn.Conv2D(8, 1, 1))
            self.pred_geo.add(nn.Activation('sigmoid'))

    def _make_layers(self, channel, ksize=3, stride=1, padding=1, act_type='relu'):
        """1x1 reduce + kxk conv merge unit (each conv followed by norm+act)."""
        layer = nn.HybridSequential()
        layer.add(nn.Conv2D(channel, 1, 1))
        layer.add(self.norm_layer(**({} if self.norm_kwargs is None else self.norm_kwargs)))
        layer.add(nn.Activation(act_type))
        layer.add(nn.Conv2D(channel, ksize, stride, padding))
        layer.add(self.norm_layer(**({} if self.norm_kwargs is None else self.norm_kwargs)))
        layer.add(nn.Activation(act_type))
        return layer

    def hybrid_forward(self, F, x):
        # Collect the four backbone feature maps.
        feats = []
        for block in self.stages:
            x = block(x)
            feats.append(x)
        # Deepest level first for the top-down merge.
        feats = feats[::-1]
        h = feats[0]
        for i in range(3):
            # Upsample 2x, concatenate the skip connection, merge with convs.
            h = F.UpSampling(h, scale=2, sample_type='nearest')
            h = F.concat(h, feats[i + 1], dim=1)
            h = self.convs[i](h)
        scores = self.pred_score(h)
        # Sigmoid output in [0, 1] mapped to pixel offsets in [-800, 800].
        geometrys = (self.pred_geo(h) - 0.5) * 2 * 800
        return scores, geometrys

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load weights, run a dummy 512x512 forward pass and export."""
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 512, 512), dtype='float32', ctx=ctx[0])
        self.load_parameters(param_path)
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        pred1 = self(data)
        self.export(prefix, epoch=0)
def get_east(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
    """Build an EAST detector on top of a named backbone.

    Parameters
    ----------
    backbone_name : str
        One of e.g. 'resnet_v1_50', 'resnext_50', 'mobilenetv3_small'.
    norm_layer, norm_kwargs
        Normalization layer class and its kwargs, forwarded to the backbone.
    **kwargs
        Forwarded to the EAST constructor.

    Raises
    ------
    ValueError
        If the backbone name is not recognized.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs, use_se=False)
        ids = [5, 6, 7, 8]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 7, 8]
    elif 'mobilenetv3' in name:
        # Bug fix: the variant name ('small'/'large') is a string; the original
        # wrapped it in int(), which always raised ValueError.
        model_name = name.split('_')[1]
        ids = [4, 6, 12, 17] if model_name == 'small' else [6, 9, 15, 21]
        base_net = get_mobilenet_v3(model_name, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
    else:
        raise ValueError('Input right backbone name.')
    # Slice the backbone into four stages at the feature indices above.
    stages = [base_net.features[:ids[0]],
              base_net.features[ids[0]:ids[1]],
              base_net.features[ids[1]:ids[2]],
              base_net.features[ids[2]:ids[3]],
              ]
    net = EAST(stages, **kwargs)
    return net