基于 MXNet 实现 OCR 网络结构【网络结构源码】

文章目录

  • 前言
  • AttModel
  • CLRS
  • CRNN
  • dbNet
  • EAST


前言

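本文基于 MXNet（Gluon API）实现几种常用的 OCR 网络结构：识别模型 AttModel（注意力编解码）和 CRNN，以及检测模型 CLRS、DBNet 和 EAST。每个模型类都提供 export_block 方法，用于导出 hybridize 之后的推理模型。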


AttModel


class AttModel(nn.HybridBlock):
    """Attention-based encoder/decoder text recognizer.

    The encoder maps ``(x, mask)`` to ``(en_out, en_proj, mask)``; the decoder
    is then unrolled over the time axis of ``targ_input`` with
    ``F.contrib.foreach``. In training mode every step is fed the
    corresponding ``targ_input`` token (teacher forcing). Otherwise each step
    is fed the greedy (argmax) prediction of the previous step, starting from
    ``start_symbol``.

    Args:
        encoder: block called as ``encoder(x, mask)`` returning
            ``(en_out, en_proj, mask)``.
        decoder: recurrent decoder block; it must expose an ``lstm``
            attribute, which ``begin_state`` delegates to.
        start_symbol: token id used as the first decoder input at inference.
        end_symbol: stored on the instance; not used inside this class.
    """
    def __init__(self, encoder, decoder, start_symbol=0, end_symbol=1, **kwargs):
        super(AttModel, self).__init__(**kwargs)
        self.start_symbol = start_symbol
        self.end_symbol = end_symbol
        with self.name_scope():
            self.encoder = encoder
            self.decoder = decoder

    def hybrid_forward(self, F, x, mask, targ_input):
        """Encode ``x`` and unroll the decoder over ``targ_input``'s time axis.

        Args:
            x: input image batch.
            mask: input mask forwarded to the encoder.
            targ_input: 2-D token tensor laid out ``(batch, time)``; it is
                transposed with ``axes=(1, 0)``. In training it supplies the
                teacher-forcing tokens. At inference it fixes the number of
                decode steps and is multiplied into each step's prediction.

        Returns:
            Training: decoder outputs made batch-major with a final transpose.
            Inference: greedy token ids, flattened per step and transposed to
            batch-major.
        """
        if isinstance(x, mx.ndarray.NDArray):
            # Imperative mode: the concrete batch size, context and dtype are known.
            batch_size = x.shape[0]
            state = self.begin_state(func=mx.ndarray.zeros, ctx=x.context, batch_size=batch_size, dtype=x.dtype)
        else:
            # Symbolic (hybridized) mode: the shapes are inferred later.
            state = self.begin_state(func=mx.symbol.zeros)

        en_out, en_proj, mask = self.encoder(x, mask)
        states = [en_out, en_proj, mask, state]
        # (batch, time) -> (time, batch, 1) so that foreach iterates over time.
        tag_input = F.transpose(targ_input, axes=(1, 0)).expand_dims(axis=-1)

        def train_func(out, states):
            # Teacher forcing: the ground-truth token for this step is the input.
            outs = self.decoder(out, states)
            return outs[0], outs[1:]

        if autograd.is_training():
            outputs, states = F.contrib.foreach(train_func, tag_input, states)
            # Merge axes 1 and 2 of the stacked outputs, then go batch-major.
            outputs = F.reshape(outputs, shape=(0, -3, 0))
            outputs = F.transpose(outputs, axes=(1, 0, 2))
            return outputs

        def test_func(inp, states):
            # Greedy decoding: states[0] carries the previous step's prediction.
            # NOTE(review): train_func calls decoder(inp, states) while this calls
            # decoder(*states); confirm the decoder accepts both call forms.
            outs = self.decoder(*states)
            pred = F.softmax(outs[0], axis=-1)
            pred = F.argmax(pred, axis=-1)
            states = [pred * inp] + list(outs[1:])
            return pred, states

        # First decoder input: the first time step of targ_input scaled by start_symbol.
        first_input = F.slice_axis(tag_input, axis=0, begin=0, end=1)
        first_input = F.squeeze(first_input, axis=0) * self.start_symbol
        states = [first_input] + states
        outputs, states = F.contrib.foreach(test_func, tag_input, states)
        outputs = F.reshape(outputs, (0, -1))
        outputs = F.transpose(outputs, axes=(1, 0))
        return outputs

    def __call__(self, data, mask, targ_input):
        return super(AttModel, self).__call__(data, mask, targ_input)

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Hybridize, run one dummy forward pass and export symbol + params.

        Args:
            prefix: output path prefix passed to ``self.export``.
            param_path: trained parameter file to load before exporting. If
                it is falsy (e.g. ``None``), the parameters are freshly
                initialized instead, which matches the previous behaviour.
            ctx: a context or a list of contexts to place the parameters on.
        """
        if not isinstance(ctx, list):
            ctx = [ctx]
        # Dummy inputs that only drive graph construction for export.
        data = mx.nd.ones((1, 3, 32, 128), ctx=ctx[0])
        mask = mx.nd.ones((1, 1, 1, 16), ctx=ctx[0])
        targ_input = mx.nd.ones((1, 10), ctx=ctx[0])
        # Previously param_path was ignored and random weights were exported.
        if param_path:
            self.load_parameters(param_path)
        else:
            self.initialize()
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        self.__call__(data, mask, targ_input)
        self.export(prefix)

    def begin_state(self, *args, **kwargs):
        """Return the initial recurrent state from ``self.decoder.lstm``."""
        return self.decoder.lstm.begin_state(*args, **kwargs)



CLRS

import mxnet as mx
from mxnet import autograd
from mxnet import gluon
from mxnet.gluon import nn
from nets.anchor import CLRSAnchorGenerator
from nets.backbone.resnet import get_resnet
from nets.backbone.mobilenetv3 import get_mobilenet_v3
from nets.backbone.resnext import get_resnext

class MultiPerClassDecoder(gluon.HybridBlock):
    """Split per-class scores and attach a class-id tensor of the same shape.

    Class 0 (the first entry along ``axis``) is treated as background and
    dropped. Entries whose score does not exceed ``thresh`` get id ``-1`` and
    score ``0``.

    Args:
        num_class: number of classes including background.
        axis: axis that holds the class scores (the id template below always
            uses the last axis).
        thresh: minimum score kept.
    """
    def __init__(self, num_class, axis=-1, thresh=0.01):
        super(MultiPerClassDecoder, self).__init__()
        self._fg_class = num_class - 1
        self._axis = axis
        self._thresh = thresh

    def hybrid_forward(self, F, x):
        # Foreground scores only: b x N x fg_class.
        fg_scores = x.slice_axis(axis=self._axis, begin=1, end=None)
        zero_col = F.zeros_like(x.slice_axis(axis=-1, begin=0, end=1))
        id_row = F.reshape(F.arange(self._fg_class), shape=(1, 1, self._fg_class))
        ids = F.broadcast_add(zero_col, id_row)
        keep = fg_scores > self._thresh
        ids = F.where(keep, ids, F.ones_like(ids) * -1)
        fg_scores = F.where(keep, fg_scores, F.zeros_like(fg_scores))
        return ids, fg_scores

class BBoxCornerToCenter(gluon.HybridBlock):
    """Convert ``(xmin, ymin, xmax, ymax)`` boxes to ``(x, y, width, height)`` centers.

    Args:
        axis: axis that holds the four coordinates.
        split: if true, return the four components separately; otherwise
            concatenate them back along ``axis``.
    """
    def __init__(self, axis=-1, split=False):
        super(BBoxCornerToCenter, self).__init__()
        self._split = split
        self._axis = axis

    def hybrid_forward(self, F, x):
        left, top, right, bottom = F.split(x, axis=self._axis, num_outputs=4)
        w = right - left
        h = bottom - top
        cx = left + w * 0.5
        cy = top + h * 0.5
        if self._split:
            return cx, cy, w, h
        return F.concat(cx, cy, w, h, dim=self._axis)

class NormalizedBoxCenterDecoder(gluon.HybridBlock):
    """Decode normalized center-offset regressions against anchors into boxes.

    Uses ``F.contrib.box_decode`` when the backend provides it. Otherwise it
    falls back to an explicit implementation that returns corner boxes
    ``(xmin, ymin, xmax, ymax)``.

    Args:
        stds: four scaling factors applied to the predicted offsets.
        convert_anchor: if true, anchors are given in corner format and are
            converted to centers first.
        clip: optional upper bound applied to the scaled width/height log
            offsets before ``exp``.
    """
    def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), convert_anchor=False, clip=None):
        super(NormalizedBoxCenterDecoder, self).__init__()
        assert len(stds) == 4, "Box Encoder requires 4 std values."
        self._stds = stds
        self._clip = clip
        if convert_anchor:
            self.corner_to_center = BBoxCornerToCenter(split=True)
        else:
            self.corner_to_center = None
        self._format = 'corner' if convert_anchor else 'center'

    def hybrid_forward(self, F, x, anchors):
        if 'box_decode' in F.contrib.__dict__:
            x, anchors = F.amp_multicast(x, anchors, num_outputs=2, cast_narrow=True)
            # -1 is passed to box_decode when no clip is configured. Compute it
            # locally instead of overwriting self._clip during a forward pass.
            clip = -1 if self._clip is None else self._clip
            return F.contrib.box_decode(x, anchors, self._stds[0], self._stds[1], self._stds[2], self._stds[3], clip=clip, format=self._format)
        if self.corner_to_center is not None:
            a = self.corner_to_center(anchors)
        else:
            a = anchors.split(axis=-1, num_outputs=4)
        p = F.split(x, axis=-1, num_outputs=4)
        # Center = anchor center + scaled offset * anchor size.
        ox = F.broadcast_add(F.broadcast_mul(p[0] * self._stds[0], a[2]), a[0])
        oy = F.broadcast_add(F.broadcast_mul(p[1] * self._stds[1], a[3]), a[1])
        dw = p[2] * self._stds[2]
        dh = p[3] * self._stds[3]
        if self._clip:
            dw = F.minimum(dw, self._clip)
            dh = F.minimum(dh, self._clip)
        dw = F.exp(dw)
        dh = F.exp(dh)
        # Half width/height, used to build corner coordinates.
        ow = F.broadcast_mul(dw, a[2]) * 0.5
        oh = F.broadcast_mul(dh, a[3]) * 0.5
        return F.concat(ox - ow, oy - oh, ox + ow, oy + oh, dim=-1)

class DM(nn.HybridBlock):
    """Feature-fusion module: transposed-conv a coarse map, combine with a finer one.

    ``x1`` passes through a transposed conv (``ksize``/``strides``), a 3x3
    conv and BatchNorm. ``x2`` passes through two 3x3 conv + BatchNorm stages
    with a ReLU in between. The result is ``relu(x1 * x2)``.

    Note: ``pad`` is accepted but not applied; the transposed conv uses its
    default padding.
    """
    def __init__(self, channels=128, ksize=2, strides=2, pad=0, **kwargs):
        super(DM, self).__init__(**kwargs)
        with self.name_scope():
            # Layers are created in the same order as before to keep parameter names stable.
            self.deconv = nn.HybridSequential()
            self.deconv.add(
                nn.Conv2DTranspose(channels, ksize, strides),
                nn.Conv2D(channels, 3, 1, 1),
                nn.BatchNorm(),
            )
            self.conv = nn.HybridSequential()
            self.conv.add(
                nn.Conv2D(channels, 3, 1, 1),
                nn.BatchNorm(),
                nn.Activation('relu'),
                nn.Conv2D(channels, 3, 1, 1),
                nn.BatchNorm(),
            )

    def hybrid_forward(self, F, x1, x2):
        coarse = self.deconv(x1)
        fine = self.conv(x2)
        return F.relu(coarse * fine)

class PM(nn.HybridBlock):
    """Prediction module: residual 1x1 trunk followed by class and box heads.

    Args:
        channels: width of the skip/bone convs.
        k: number of predictions per spatial location.
        num_classes: foreground classes; the score head outputs
            ``k * (num_classes + 1)`` channels, the box head ``k * 4``.

    Returns ``(score, offset)`` feature maps.
    """
    def __init__(self, channels=256, k=4, num_classes=4, **kwargs):
        super(PM, self).__init__(**kwargs)
        with self.name_scope():
            self.skip = nn.HybridSequential()
            self.skip.add(nn.Conv2D(channels, 1, 1))
            self.bone = nn.HybridSequential()
            self.bone.add(*[nn.Conv2D(channels, 1, 1) for _ in range(3)])
            self.conf = nn.Conv2D(k * (num_classes + 1), 3, 1, 1)
            self.loc = nn.Conv2D(k * 4, 3, 1, 1)

    def hybrid_forward(self, F, x):
        fused = F.relu(self.skip(x) + self.bone(x))
        return self.conf(fused), self.loc(fused)

class SM(nn.HybridBlock):
    """Segmentation module: residual 1x1 block, optionally upsampled.

    Output is ``relu(skip(x) + bone(x))``. When ``n_scale > 1`` it is
    nearest-neighbour upsampled by ``n_scale``.
    """
    def __init__(self, channels, n_scale=2, **kwargs):
        super(SM, self).__init__(**kwargs)
        self.n_scale = n_scale
        with self.name_scope():
            self.skip = nn.HybridSequential()
            self.skip.add(nn.Conv2D(channels, 1, 1), nn.BatchNorm())
            self.bone = nn.HybridSequential()
            # Two conv-bn-relu stages, then a final conv-bn without activation.
            for _ in range(2):
                self.bone.add(nn.Conv2D(channels, 1, 1), nn.BatchNorm(), nn.Activation('relu'))
            self.bone.add(nn.Conv2D(channels, 1, 1), nn.BatchNorm())

    def hybrid_forward(self, F, x):
        out = F.relu(self.skip(x) + self.bone(x))
        if self.n_scale <= 1:
            return out
        return F.UpSampling(out, scale=self.n_scale, sample_type='nearest')

class SegPred(nn.HybridBlock):
    """Fuse five feature maps into a 4-channel sigmoid segmentation output.

    Input ``i`` goes through an ``SM`` with upsampling factor
    16, 8, 4, 2 and 1 respectively. The results are summed, passed through a
    ReLU, and a tail of convs with two stride-2 transposed convs produces the
    final 4-channel map.
    """
    def __init__(self, channels, **kwargs):
        super(SegPred, self).__init__(**kwargs)
        with self.name_scope():
            self.sms = nn.HybridSequential()
            for scale in (16, 8, 4, 2, 1):
                self.sms.add(SM(channels, scale))

            self.tail = nn.HybridSequential()
            self.tail.add(
                nn.Conv2D(channels, 1, 1),
                nn.BatchNorm(),
                nn.Activation('relu'),
                nn.Conv2DTranspose(channels, 2, 2),
                nn.Conv2D(channels, 3, 1, 1),
                nn.BatchNorm(),
                nn.Activation('relu'),
                nn.Conv2DTranspose(4, 2, 2),
                nn.Activation('sigmoid'),
            )

    def hybrid_forward(self, F, *xs):
        upsampled = [sm(feat) for feat, sm in zip(xs, self.sms)]
        return self.tail(F.relu(F.add_n(*upsampled)))

class CLRS(nn.HybridBlock):
    """Text detector: multi-scale box prediction plus segmentation maps.

    The backbone ``stages`` are followed by four extra conv layers. A chain
    of six ``DM`` blocks fuses features coarse-to-fine. One ``PM`` head per
    fused scale predicts class scores and box offsets, and ``SegPred`` builds
    segmentation maps from five of the fused scales.

    Args:
        stages: backbone blocks applied in sequence. Together with the four
            extra layers they must yield at least seven feature maps, because
            ``hybrid_forward`` indexes ``feats[-7]``.
        sizes, ratios, steps: per-scale anchor settings, zipped together. One
            ``PM`` head and one ``CLRSAnchorGenerator`` are built per entry.
        dm_channels, pm_channels, sm_channels: channel widths of the DM, PM
            and SegPred sub-blocks.
        stds: standard deviations passed to the box decoder.
        nms_thresh, nms_topk, post_nms: inference NMS settings (see ``set_nms``).
        anchor_alloc_size: anchor allocation size for the first scale. It is
            halved for each following scale, with a floor of 16.
    """
    def __init__(self, stages, sizes, ratios, steps, dm_channels=256,
                 pm_channels=256, sm_channels=32,
                 stds=(0.1, 0.1, 0.2, 0.2), nms_thresh=0.45,
                 nms_topk=1000, post_nms=400,
                 anchor_alloc_size=256, **kwargs):
        super(CLRS, self).__init__(**kwargs)

        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        self.post_nms = post_nms

        with self.name_scope():
            self.stages = nn.HybridSequential()
            for i in range(len(stages)):
                self.stages.add(stages[i])
            # extra layers: two stride-2 blocks followed by two stride-1 blocks
            # (each block's 3x3 conv uses padding = strides - 1).
            self.extras = nn.HybridSequential()
            self.extras.add(self._extra_layer(256, 512))
            self.extras.add(self._extra_layer(128, 256))
            self.extras.add(self._extra_layer(128, 256, strides=1))
            self.extras.add(self._extra_layer(128, 256, strides=1))
            # DM 0-1: 3x3 stride-1 transposed conv; DM 2-5: 2x2 stride-2 transposed conv.
            self.dms = nn.HybridSequential()
            for i in range(6):
                strides = 2 if i > 1 else 1
                ksize = 2 if strides == 2 else 3
                self.dms.add(DM(dm_channels, ksize, strides=strides, pad=ksize - 2))
            self.pms = nn.HybridSequential()
            self.anchor_generators = nn.HybridSequential()
            asz = anchor_alloc_size
            for i, (s, r, st) in enumerate(zip(sizes, ratios, steps)):
                # PM predicts len(s) entries per location (presumably one per anchor size).
                self.pms.add(PM(pm_channels, len(s)))
                # NOTE(review): the image size for anchor generation is fixed at 512x512 here.
                anchor_generator = CLRSAnchorGenerator(i, (512, 512), s, r, st, alloc_size=(asz, asz))
                self.anchor_generators.add(anchor_generator)
                asz = max(asz // 2, 16)
            self.seg_pred = SegPred(sm_channels)
            self.bbox_decoder = NormalizedBoxCenterDecoder(stds)
            # 4 foreground classes + background, matching PM's default num_classes=4.
            self.cls_decoder = MultiPerClassDecoder(4 + 1, thresh=0.01)

    def _extra_layer(self, in_channels, out_channels, strides=2, norm_layer=nn.BatchNorm, norm_kwargs=None):
        """Build a 1x1 conv-norm-relu followed by a 3x3 conv-norm-relu.

        Despite its name, ``in_channels`` is the output width of the 1x1 conv.
        The 3x3 conv uses ``strides`` and padding ``strides - 1``.
        """
        layer = nn.HybridSequential()
        layer.add(nn.Conv2D(in_channels, 1, 1))
        layer.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
        layer.add(nn.Activation('relu'))
        layer.add(nn.Conv2D(out_channels, 3, strides, strides - 1))
        layer.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
        layer.add(nn.Activation('relu'))
        return layer

    def set_nms(self, nms_thresh=0.45, nms_topk=400, post_nms=100):
        """Update the NMS settings.

        Also clears the cached hybridized op, because the values are read as
        Python constants while the graph is built.
        """
        self._clear_cached_op()
        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        self.post_nms = post_nms

    def hybrid_forward(self, F, x):
        """Run the detector.

        Returns:
            Training (``autograd.is_training()``): ``[cls_preds, box_preds,
            anchors, seg_maps]``. ``cls_preds`` is reshaped to
            ``(batch, -1, 5)``, ``box_preds`` to ``(batch, -1, 4)`` and
            ``anchors`` to ``(1, -1, 4)``.
            Inference: ``(ids, scores, bboxes, seg_maps)``, sliced from the
            concatenated per-class ``[id, score, box(4)]`` rows after optional
            NMS.
        """
        # Backbone stage outputs followed by extra-layer outputs.
        feats = []
        for block in self.stages:
            x = block(x)
            feats.append(x)

        for block in self.extras:
            x = block(x)
            feats.append(x)

        # Coarse-to-fine fusion: each DM combines the previously fused map with
        # the next finer raw feature map.
        F11 = feats[-1]
        F10 = self.dms[0](F11, feats[-2])
        F9 = self.dms[1](F10, feats[-3])
        F8 = self.dms[2](F9, feats[-4])
        F7 = self.dms[3](F8, feats[-5])
        F4 = self.dms[4](F7, feats[-6])
        F3 = self.dms[5](F4, feats[-7])

        # Fused maps ordered fine-to-coarse; zipped with the PM heads and anchor generators.
        feats = [F3, F4, F7, F8, F9, F10, F11]
        box_preds, cls_preds = [], []
        for feat, block in zip(feats, self.pms):
            pred = block(feat)
            # NCHW -> NHWC before flattening, so each location's predictions are contiguous.
            cls_preds.append(F.flatten(F.transpose(pred[0], (0, 2, 3, 1))))
            box_preds.append(F.flatten(F.transpose(pred[1], (0, 2, 3, 1))))
        cls_preds = F.concat(*cls_preds, dim=1).reshape((0, -1, 5))
        box_preds = F.concat(*box_preds, dim=1).reshape((0, -1, 4))
        anchors = [F.reshape(ag(feat), shape=(1, -1))
                   for feat, ag in zip(feats, self.anchor_generators)]
        anchors = F.concat(*anchors, dim=1).reshape((1, -1, 4))

        seg_maps = self.seg_pred(F9, F8, F7, F4, F3)
        if autograd.is_training():
            return [cls_preds, box_preds, anchors, seg_maps]

        bboxes = self.bbox_decoder(box_preds, anchors)
        cls_ids, scores = self.cls_decoder(F.softmax(cls_preds, axis=-1))
        # Build one [id, score, box] row per anchor and per foreground class.
        results = []
        for i in range(4):
            cls_id = cls_ids.slice_axis(axis=-1, begin=i, end=i + 1)
            score = scores.slice_axis(axis=-1, begin=i, end=i + 1)
            # per class results
            per_result = F.concat(*[cls_id, score, bboxes], dim=-1)
            results.append(per_result)
        result = F.concat(*results, dim=1)
        # NMS runs only for a threshold strictly between 0 and 1.
        if self.nms_thresh > 0 and self.nms_thresh < 1:
            result = F.contrib.box_nms(
                result, overlap_thresh=self.nms_thresh, topk=self.nms_topk, valid_thresh=0.01,
                id_index=0, score_index=1, coord_start=2, force_suppress=False)
            if self.post_nms > 0:
                result = result.slice_axis(axis=1, begin=0, end=self.post_nms)
        ids = F.slice_axis(result, axis=2, begin=0, end=1)
        scores = F.slice_axis(result, axis=2, begin=1, end=2)
        bboxes = F.slice_axis(result, axis=2, begin=2, end=6)
        return ids, scores, bboxes, seg_maps

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load ``param_path``, hybridize, run a 1x3x512x512 dummy pass and export.

        NMS is set to (0.45, 1000, 400) before exporting. ``ctx`` may be a
        single context or a list of contexts.
        """
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 512, 512), dtype='float32', ctx=ctx[0])
        self.load_parameters(param_path)
        self.set_nms(0.45, 1000, 400)
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        pred1 = self(data)
        self.export(prefix, epoch=0)

def get_clrs(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None):
    """Build a CLRS detector on a named backbone.

    Args:
        backbone_name: case-insensitive name containing ``'resnet'``,
            ``'resnext'`` or ``'mobilenetv3'``. The token after the first
            ``'_'`` is the layer count for ResNet/ResNeXt, or the MobileNetV3
            variant (``'small'``; anything else builds ``'large'``). For ResNet,
            ``'v1'`` anywhere in the name selects version 1, otherwise version 2.
        norm_layer, norm_kwargs: normalization layer forwarded to the backbone.

    Returns:
        A ``CLRS`` network built from three backbone stages.

    Raises:
        ValueError: for an unsupported backbone name, or a non-integer layer
            count for ResNet/ResNeXt.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, strides=[(2, 2), (1, 1), (2, 2), (2, 2), (1, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 8]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, strides=[(2, 2), (1, 1), (2, 2), (2, 2), (1, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 8]
    elif 'mobilenetv3' in name:
        # The variant token is a string such as 'small'; int() on it raised ValueError.
        model_name = name.split('_')[1]
        if model_name == 'small':
            ids = [4, 6, 17]
            base_net = get_mobilenet_v3('small', strides=[(2, 2), (2, 2), (2, 2), (1, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        else:
            ids = [6, 9, 21]
            base_net = get_mobilenet_v3('large', norm_layer=norm_layer, norm_kwargs=norm_kwargs)
    else:
        raise ValueError('The %s is not support.' % backbone_name)

    # Split the backbone feature extractor into three consecutive stages at `ids`.
    stages = [base_net.features[:ids[0]],
              base_net.features[ids[0]:ids[1]],
              base_net.features[ids[1]:ids[2]],
              ]
    # Per-scale anchor sizes, aspect ratios and steps for the seven fused maps.
    sizes = [[4, 6, 8, 10, 12, 16], [20, 24, 28, 32], [36, 40, 44, 48],
             [56, 64, 72, 80], [88, 96, 104, 112], [124, 136, 148, 160],
             [184, 208, 232, 256]]
    ratios = [[1.0] for _ in range(7)]
    steps = [4, 8, 16, 32, 64, 86, 128]
    net = CLRS(stages, sizes, ratios, steps)
    return net


CRNN

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from nets.stn import STN
from nets.rnn_layer import RNNLayer
from nets.backbone.resnet import get_resnet
from nets.backbone.resnext import get_resnext
from nets.backbone.vgg import get_vgg
from nets.backbone.mobilenetv3 import get_mobilenet_v3

class CRNN(nn.HybridBlock):
    """CNN feature extractor + recurrent layer + per-step classifier (CRNN).

    Args:
        stages: convolutional feature extractor block.
        hidden_size, num_layers, dropout, birnn: recurrent layer settings.
            ``dropout`` is also applied after the recurrent layer.
        voc_size: number of output classes per time step.
        use_stn: prepend an ``STN`` block.
        LSTMFlag: when equal to ``False`` the project ``RNNLayer`` is used,
            otherwise ``gluon.rnn.LSTM``.
    """
    def __init__(self, stages, hidden_size=256, num_layers=2, dropout=0.1,
                 voc_size=37, birnn=True, use_stn=False, LSTMFlag=True,
                 **kwargs):
        super(CRNN, self).__init__(**kwargs)

        self.use_stn = use_stn
        with self.name_scope():
            self.stages = stages
            if use_stn:
                self.stn = STN()
            self.lstm = (
                RNNLayer('lstm', num_layers, hidden_size, dropout=dropout, bidirectional=birnn, layout='NTC')
                if LSTMFlag == False
                else gluon.rnn.LSTM(hidden_size, num_layers, dropout=dropout, bidirectional=birnn, layout='NTC')
            )
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Dense(voc_size, flatten=False)

    def hybrid_forward(self, F, x):
        feat = self.stn(x) if self.use_stn else x
        feat = self.stages(feat)
        # Axes (0, 3, 2, 1) then merge axes 1-2: with NCHW features (assumed)
        # this gives (N, W*H, C), i.e. width becomes the time axis for 'NTC'.
        seq = F.reshape(F.transpose(feat, axes=(0, 3, 2, 1)), (0, -3, 0))
        return self.fc(self.dropout(self.lstm(seq)))

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load ``param_path``, hybridize, run a 1x3x32x320 dummy pass and export."""
        contexts = ctx if isinstance(ctx, list) else [ctx]
        dummy = mx.nd.ones((1, 3, 32, 320), dtype='float32', ctx=contexts[0])
        self.load_parameters(param_path)
        self.hybridize()
        self.collect_params().reset_ctx(contexts)
        self(dummy)
        self.export(prefix, epoch=0)

def get_crnn(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
    """Build a CRNN recognizer on a named backbone.

    Args:
        backbone_name: case-insensitive name containing ``'resnet'``,
            ``'resnext'``, ``'vgg'`` or ``'mobilenetv3'``. The token after the
            first ``'_'`` is the layer count, or the MobileNetV3 variant name.
            For ResNet, ``'v1'`` in the name selects version 1, otherwise 2.
        norm_layer, norm_kwargs: normalization layer forwarded to the backbone.
        **kwargs: forwarded to ``CRNN``.

    Raises:
        ValueError: for an unsupported backbone name, or a non-integer layer count.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, strides=[(2, 1), (1, 1), (2, 2), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-1]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, strides=[(2, 1), (1, 1), (2, 2), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-1]
    elif 'vgg' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_vgg(num_layers, strides=[(2, 1), (2, 2), (2, 2), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-4]
    elif 'mobilenetv3' in name:
        # The variant token is a string such as 'small'; int() on it raised ValueError.
        model_name = name.split('_')[1]
        base_net = get_mobilenet_v3(model_name, strides=[(2, 2), (2, 1), (2, 1), (2, 1)], norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        backbone = base_net.features[:-4]
    else:
        raise ValueError('Please input right backbone name.')
    crnn = CRNN(backbone, **kwargs)
    return crnn


dbNet

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from nets.backbone.resnet import get_resnet
from nets.backbone.mobilenetv3 import get_mobilenet_v3
from nets.backbone.resnext import get_resnext

class DBNet(gluon.HybridBlock):
    """Differentiable-Binarization text detector over a four-stage backbone.

    Each stage output is projected to ``inner_channels`` by a 1x1 conv and
    merged top-down (bilinear resize to the finer map, then add). Each level
    is then reduced to ``inner_channels // 4`` channels, resized to the finest
    level, and all levels are concatenated. The ``binarize`` head predicts a
    sigmoid probability map. When ``adaptive`` is true, the ``thresh`` head
    also predicts a threshold map, and the approximate binary map
    ``1 / (1 + exp(-k * (binary - thresh)))`` is returned as well.
    """
    def __init__(self, stages, inner_channels=256, k=10, use_bias=False, adaptive=True,
                 norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
        super(DBNet, self).__init__(**kwargs)

        self.k = k
        self.adaptive = adaptive
        head_channels = inner_channels // 4
        with self.name_scope():
            self.stages = nn.HybridSequential()
            self.ins_proj = nn.HybridSequential()
            self.outs = nn.HybridSequential()
            for stage in stages:
                self.stages.add(stage)
                self.ins_proj.add(nn.Conv2D(inner_channels, 1, use_bias=use_bias))
                self.outs.add(nn.Conv2D(head_channels, 3, padding=1, use_bias=use_bias))

            self.binarize = self._make_head(head_channels, use_bias, norm_layer, norm_kwargs)
            if adaptive:
                self.thresh = self._make_head(head_channels, use_bias, norm_layer, norm_kwargs)

    @staticmethod
    def _make_head(mid_channels, use_bias, norm_layer, norm_kwargs):
        """Build a conv-norm-relu, two stride-2 transposed convs, sigmoid 1-channel head."""
        norm_args = {} if norm_kwargs is None else norm_kwargs
        head = nn.HybridSequential()
        head.add(
            nn.Conv2D(mid_channels, 3, padding=1, use_bias=use_bias),
            norm_layer(**norm_args),
            nn.Activation('relu'),
            nn.Conv2DTranspose(mid_channels, 2, 2),
            norm_layer(**norm_args),
            nn.Activation('relu'),
            nn.Conv2DTranspose(1, 2, 2),
            nn.Activation('sigmoid'),
        )
        return head

    def hybrid_forward(self, F, x):
        def resize_like(src, ref):
            return F.contrib.BilinearResize2D(src, like=ref, mode='like')

        raw = []
        for stage in self.stages:
            x = stage(x)
            raw.append(x)
        in2, in3, in4, in5 = [proj(feat) for proj, feat in zip(self.ins_proj, raw)]

        # Top-down pathway.
        out4 = resize_like(in5, in4) + in4
        out3 = resize_like(out4, in3) + in3
        out2 = resize_like(out3, in2) + in2

        pyramid = (out2, out3, out4, in5)
        fuse = F.concat(*[resize_like(head(feat), in2) for feat, head in zip(pyramid, self.outs)], dim=1)
        binary = self.binarize(fuse)
        if not self.adaptive:
            return binary
        thresh = self.thresh(F.concat(fuse, resize_like(binary, fuse), dim=1))
        thresh_binary = 1.0 / (1.0 + F.exp(-self.k * (binary - thresh)))
        return binary, thresh, thresh_binary

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load ``param_path``, hybridize, run a 1x3x512x512 dummy pass and export."""
        contexts = ctx if isinstance(ctx, list) else [ctx]
        dummy = mx.nd.ones((1, 3, 512, 512), dtype='float32', ctx=contexts[0])
        self.load_parameters(param_path)
        self.hybridize()
        self.collect_params().reset_ctx(contexts)
        self(dummy)
        self.export(prefix, epoch=0)

def get_db(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
    """Build a DBNet detector on a named backbone.

    Args:
        backbone_name: case-insensitive name containing ``'resnet'``,
            ``'resnext'`` or ``'mobilenetv3'``. The token after the first
            ``'_'`` is the layer count, or the MobileNetV3 variant
            (``'small'``; any other value uses the large split indices).
        norm_layer, norm_kwargs: normalization layer forwarded to the backbone.
        **kwargs: forwarded to ``DBNet``.

    Raises:
        ValueError: for an unsupported backbone name, or a non-integer layer count.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 7, 8]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 7, 8]
    elif 'mobilenetv3' in name:
        # The variant token is a string such as 'small'; int() on it raised ValueError.
        model_name = name.split('_')[1]
        ids = [4, 6, 12, 17] if model_name == 'small' else [6, 9, 15, 21]
        base_net = get_mobilenet_v3(model_name, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
    else:
        raise ValueError('The %s is not support.' % backbone_name)
    # Split the backbone feature extractor into four consecutive stages at `ids`.
    stages = [base_net.features[:ids[0]],
              base_net.features[ids[0]:ids[1]],
              base_net.features[ids[1]:ids[2]],
              base_net.features[ids[2]:ids[3]],
              ]
    net = DBNet(stages, **kwargs)
    return net



EAST

from mxnet.gluon import nn
import mxnet as mx
from nets.backbone.resnet import get_resnet
from nets.backbone.resnext import get_resnext
from nets.backbone.mobilenetv3 import get_mobilenet_v3

class EAST(nn.HybridBlock):
    """EAST-style text detector: U-shaped feature merging + score/geometry heads.

    The deepest stage output is repeatedly upsampled x2 (nearest), concatenated
    with the next shallower stage output and refined by a 1x1 + 3x3 conv
    block. This is done three times, so four stages are expected.

    Args:
        stages: backbone blocks applied in sequence.
        channels: widths of the three merge blocks (first three entries) and
            of the heads' hidden conv (last entry). Any indexable sequence
            works; the default is a tuple to avoid a shared mutable default.
        norm_layer, norm_kwargs: normalization layer used throughout.

    Returns (from ``hybrid_forward``):
        ``scores``: 1-channel sigmoid map.
        ``geometrys``: 8-channel sigmoid output rescaled to (-800, 800).
    """
    def __init__(self, stages, channels=(128, 128, 128, 32), norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):

        super(EAST, self).__init__(**kwargs)
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs

        with self.name_scope():
            self.stages = nn.HybridSequential()
            self.convs = nn.HybridSequential()
            for i in range(len(stages)):
                self.stages.add(stages[i])
            for i in range(3):
                self.convs.add(self._make_layers(channels[i]))

            self.pred_score = nn.HybridSequential()
            self.pred_score.add(nn.Conv2D(channels[-1], 3, 1, 1))
            self.pred_score.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
            self.pred_score.add(nn.Activation('relu'))
            self.pred_score.add(nn.Conv2D(1, 1, 1))
            self.pred_score.add(nn.Activation('sigmoid'))

            self.pred_geo = nn.HybridSequential()
            self.pred_geo.add(nn.Conv2D(channels[-1], 3, 1, 1))
            self.pred_geo.add(norm_layer(**({} if norm_kwargs is None else norm_kwargs)))
            self.pred_geo.add(nn.Activation('relu'))
            self.pred_geo.add(nn.Conv2D(8, 1, 1))
            self.pred_geo.add(nn.Activation('sigmoid'))

    def _make_layers(self, channel, ksize=3, stride=1, padding=1, act_type='relu'):
        """Build a 1x1 conv-norm-act followed by a ``ksize`` conv-norm-act."""
        layer = nn.HybridSequential()
        layer.add(nn.Conv2D(channel, 1, 1))
        layer.add(self.norm_layer(**({} if self.norm_kwargs is None else self.norm_kwargs)))
        layer.add(nn.Activation(act_type))
        layer.add(nn.Conv2D(channel, ksize, stride, padding))
        layer.add(self.norm_layer(**({} if self.norm_kwargs is None else self.norm_kwargs)))
        layer.add(nn.Activation(act_type))
        return layer

    def hybrid_forward(self, F, x):
        feats = []
        for block in self.stages:
            x = block(x)
            feats.append(x)
        # Deepest feature first, so merging proceeds from coarse to fine.
        feats = feats[::-1]
        h = feats[0]
        for i in range(3):
            h = F.UpSampling(h, scale=2, sample_type='nearest')
            h = F.concat(h, feats[i + 1], dim=1)
            h = self.convs[i](h)
        scores = self.pred_score(h)
        # Map the sigmoid output from (0, 1) to (-800, 800).
        geometrys = (self.pred_geo(h) - 0.5) * 2 * 800
        return scores, geometrys

    def export_block(self, prefix, param_path, ctx=mx.cpu()):
        """Load ``param_path``, hybridize, run a 1x3x512x512 dummy pass and export."""
        if not isinstance(ctx, list):
            ctx = [ctx]
        data = mx.nd.ones((1, 3, 512, 512), dtype='float32', ctx=ctx[0])
        self.load_parameters(param_path)
        self.hybridize()
        self.collect_params().reset_ctx(ctx)
        self(data)
        self.export(prefix, epoch=0)

def get_east(backbone_name, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
    """Build an EAST detector on a named backbone.

    Args:
        backbone_name: case-insensitive name containing ``'resnet'``,
            ``'resnext'`` or ``'mobilenetv3'``. The token after the first
            ``'_'`` is the layer count, or the MobileNetV3 variant
            (``'small'``; any other value uses the large split indices).
        norm_layer, norm_kwargs: normalization layer forwarded to the backbone.
        **kwargs: forwarded to ``EAST``.

    Raises:
        ValueError: for an unsupported backbone name, or a non-integer layer count.
    """
    name = backbone_name.lower()
    if 'resnet' in name:
        version = 1 if 'v1' in name else 2
        num_layers = int(name.split('_')[1])
        base_net = get_resnet(version, num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs, use_se=False)
        ids = [5, 6, 7, 8]
    elif 'resnext' in name:
        num_layers = int(name.split('_')[1])
        base_net = get_resnext(num_layers, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        ids = [5, 6, 7, 8]
    elif 'mobilenetv3' in name:
        # The variant token is a string such as 'small'; int() on it raised ValueError.
        model_name = name.split('_')[1]
        ids = [4, 6, 12, 17] if model_name == 'small' else [6, 9, 15, 21]
        base_net = get_mobilenet_v3(model_name, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
    else:
        raise ValueError('Input right backbone name.')
    # Split the backbone feature extractor into four consecutive stages at `ids`.
    stages = [base_net.features[:ids[0]],
              base_net.features[ids[0]:ids[1]],
              base_net.features[ids[1]:ids[2]],
              base_net.features[ids[2]:ids[3]],
              ]
    net = EAST(stages, **kwargs)
    return net

你可能感兴趣的:(深度学习-mxnet,mxnet,python,深度学习)