MXNET下载训练好的模型

MXNet 提供了很多训练好的模型,但直接在网页上下载一直报错。后来在 https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/common/modelzoo.py 中找到了一段下载代码,稍作修改后即可直接调用来下载模型。

import os

import subprocess
import os
import errno

import mxnet as mx

# Common URL prefix; every symbol/params URL in the table below is built by
# appending a relative path to it.
_base_model_url = 'http://data.mxnet.io/models/'
# Table of downloadable pretrained models, used by download_model() when no
# meta_info is supplied.
#   key     -> model name accepted by download_model()
#   'symbol'-> URL of the model's "-symbol.json" file
#   'params'-> URL of the model's ".params" file (all entries are epoch 0000
#              except Inception-BN, which is epoch 0126)
# NOTE(review): these are plain http:// URLs on data.mxnet.io; confirm they
# still resolve before relying on them.
_default_model_info = {
    'imagenet1k-vgg-16': {'symbol':_base_model_url+'imagenet/vgg/vgg16-symbol.json',
                             'params':_base_model_url+'imagenet/vgg/vgg16-0000.params'},
    'imagenet1k-inception-bn': {'symbol':_base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json',
                             'params':_base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'},
    'imagenet1k-resnet-18': {'symbol':_base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'},
    'imagenet1k-resnet-34': {'symbol':_base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'},
    'imagenet1k-resnet-50': {'symbol':_base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'},
    'imagenet1k-resnet-101': {'symbol':_base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'},
    'imagenet1k-resnet-152': {'symbol':_base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'},
    'imagenet1k-resnext-50': {'symbol':_base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json',
                             'params':_base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'},
    'imagenet1k-resnext-101': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json',
                             'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'},
    'imagenet1k-resnext-101-64x4d': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json',
                                     'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'},
    'imagenet11k-resnet-152': {'symbol':_base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json',
                             'params':_base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'},
    'imagenet11k-place365ch-resnet-152': {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json',
                                          'params':_base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'},
    'imagenet11k-place365ch-resnet-50': {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json',
                                         'params':_base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'},
}


def download_file(url, local_fname=None, force_write=False, chunk_size=1024 * 1024):
    """Download ``url`` to ``local_fname`` and return the local path.

    Args:
        url: HTTP(S) URL of the file to fetch.
        local_fname: Destination path. Defaults to the last ``/``-separated
            component of ``url``, relative to the current directory.
        force_write: If False and ``local_fname`` already exists, skip the
            download and return the existing path unchanged.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        The local file path (``local_fname``).

    Raises:
        IOError: If the server answers with a status code other than 200.

    The payload is streamed into ``<local_fname>.part`` and only moved to
    ``local_fname`` once it is complete. An interrupted download therefore
    never leaves a truncated file that a later call (with
    ``force_write=False``) would mistake for a finished one.
    """
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname

    # requests is not installed by default; import it only when a download
    # is actually needed so the cached path above works without it.
    import requests

    dir_name = os.path.dirname(local_fname)
    if dir_name and not os.path.exists(dir_name):
        try:
            os.makedirs(dir_name)
        except OSError as exc:
            # Tolerate the directory having been created concurrently.
            if exc.errno != errno.EEXIST:
                raise

    tmp_fname = local_fname + '.part'
    r = requests.get(url, stream=True)
    try:
        # An explicit raise (not assert) so the check survives `python -O`;
        # otherwise an HTTP error page would be saved as the model file.
        if r.status_code != 200:
            raise IOError("failed to open %s" % url)
        with open(tmp_fname, 'wb') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    finally:
        # Release the pooled connection even if writing failed.
        r.close()

    # os.rename does not overwrite an existing target on Windows, so clear it
    # first (only reachable when force_write=True).
    if os.path.exists(local_fname):
        os.remove(local_fname)
    os.rename(tmp_fname, local_fname)
    return local_fname

def download_model(model_name, dst_dir='./', meta_info=None):
    """Download the symbol and parameter files of a pretrained model.

    Args:
        model_name: Key into ``meta_info`` (e.g. ``'imagenet1k-vgg-16'``).
        dst_dir: Directory to store the files in. It is created, including
            any missing parent directories, if it does not exist. An empty
            string means the current directory.
        meta_info: Mapping (or iterable of pairs) ``name -> {'symbol': url,
            'params': url}``. Defaults to ``_default_model_info``.

    Returns:
        ``(prefix, 0)``, where ``prefix`` is ``os.path.join(dst_dir,
        model_name)``. The files are saved as ``<prefix>-symbol.json`` and
        ``<prefix>-0000.params``. Returns ``(None, 0)`` if ``model_name`` is
        not in ``meta_info``.

    Raises:
        ValueError: If the entry for ``model_name`` lacks a ``'symbol'`` or
            ``'params'`` URL. Nothing is downloaded in that case.
    """
    if meta_info is None:
        meta_info = _default_model_info
    meta_info = dict(meta_info)
    if model_name not in meta_info:
        return (None, 0)
    meta = dict(meta_info[model_name])

    # Validate both URLs before touching the disk or the network, so a bad
    # entry does not leave a lone symbol file behind.
    if 'symbol' not in meta:
        raise ValueError("missing symbol url")
    if 'params' not in meta:
        raise ValueError("missing parameter file url")

    # makedirs (not mkdir) so nested targets like 'E:/Spyder/model' work even
    # when the parent does not exist yet.
    if dst_dir and not os.path.isdir(dst_dir):
        try:
            os.makedirs(dst_dir)
        except OSError as exc:
            if exc.errno != errno.EEXIST:
                raise

    prefix = os.path.join(dst_dir, model_name)
    download_file(meta['symbol'], prefix + '-symbol.json')
    download_file(meta['params'], prefix + '-0000.params')
    return (prefix, 0)

# Fetch the 'imagenet1k-vgg-16' checkpoint into <dir_path>/model and keep the
# (prefix, epoch) pair that download_model() reports for the saved files.
dir_path = 'E:/Spyder/'
prefix, epoch = download_model(
    'imagenet1k-vgg-16',
    dst_dir=os.path.join(dir_path, 'model'),
    meta_info=None,
)

(注:上面的链接已经失效,但仍可使用上述代码下载。)

还可通过mx.gluon.model_zoo.vision下载模型和参数,此为官方方法,以VGG16-bn为例。

加载模型,其中,pretrained=True表示同时下载对应的参数,root为下载参数存储的路径。

 vgg16_bn = mx.gluon.model_zoo.vision.vgg16_bn(pretrained= True, root='E:\\Spyder\\model',ctx=mx.cpu() )

此时,模型参数已下载到 root 指定的目录中(原文此处的截图已丢失)。

如果已经有params文件,可设置pretrained=False,然后通过load_parameters加载参数。

# Load weights from an existing local .params file (the path to use when the
# network was built with pretrained=False).
vgg16_bn.load_parameters('E:\\Spyder\\model\\vgg16_bn.params')

下面将vgg16_bn导出为json文件。首先需要调用hybridize()将网络切换为混合(hybrid)模式,之后才能导出json格式的模型文件:

# Switch the network to hybrid execution; this is required before export()
# can write a JSON model file.
vgg16_bn.hybridize()

然后,在导出模型文件和参数文件之前,需要进行一次前向传播:

# One forward pass is required after hybridize() and before export().
# gluoncv supplies the ImageNet evaluation preprocessing used below.
import gluoncv

img = mx.image.imread('./test_img/1.jpg')
transformed_img = gluoncv.data.transforms.presets.imagenet.transform_eval(img, resize_short=224)
# The network in this walkthrough is `vgg16_bn`; `net` was never defined.
output = vgg16_bn(transformed_img)

其实并不需要真实图像,使用全0或者全1的tensor也可以:

from mxnet import nd
# A dummy all-zeros input is enough for the required forward pass; the values
# do not matter. NOTE(review): shape assumed to be (batch, channels, height,
# width) with 224x224 images — confirm against the model's expected input.
data = nd.zeros(shape=(1,3,224,224))
output = vgg16_bn(data)

最后,使用export在指定路径导出json和params文件: 

# Write the network definition (JSON) and its parameters (for epoch 0) using
# the given path prefix.
vgg16_bn.export('E:\\Spyder\\model\\vgg16_bn', epoch=0)

你可能感兴趣的:(mxnet,深度学习)