MXNet 提供了很多预训练模型，但直接在网页上下载时一直报错。后来在 https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/common/modelzoo.py 中找到了一段下载代码，稍作修改后即可直接调用来下载模型。
import os
import subprocess
import os
import errno
import mxnet as mx
# Root URL that every download URL in the table below is built from.
# NOTE(review): the notes around this code say the original web links are
# dead; confirm this host still serves the files.
_base_model_url = 'http://data.mxnet.io/models/'
# Model name -> {'symbol': URL of the network definition (*-symbol.json),
#                'params': URL of the trained weights (*.params)}.
# download_model() stores the weights locally as '<name>-0000.params'
# (epoch 0), whatever epoch number appears in the remote file name
# (e.g. Inception-BN-0126.params).
_default_model_info = {
    'imagenet1k-vgg-16': {'symbol':_base_model_url+'imagenet/vgg/vgg16-symbol.json',
                          'params':_base_model_url+'imagenet/vgg/vgg16-0000.params'},
    'imagenet1k-inception-bn': {'symbol':_base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json',
                                'params':_base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'},
    'imagenet1k-resnet-18': {'symbol':_base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'},
    'imagenet1k-resnet-34': {'symbol':_base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'},
    'imagenet1k-resnet-50': {'symbol':_base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json',
                             'params':_base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'},
    'imagenet1k-resnet-101': {'symbol':_base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json',
                              'params':_base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'},
    'imagenet1k-resnet-152': {'symbol':_base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json',
                              'params':_base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'},
    'imagenet1k-resnext-50': {'symbol':_base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json',
                              'params':_base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'},
    'imagenet1k-resnext-101': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json',
                               'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'},
    'imagenet1k-resnext-101-64x4d': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json',
                                     'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'},
    'imagenet11k-resnet-152': {'symbol':_base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json',
                               'params':_base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'},
    'imagenet11k-place365ch-resnet-152': {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json',
                                          'params':_base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'},
    'imagenet11k-place365ch-resnet-50': {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json',
                                         'params':_base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'},
}
def download_file(url, local_fname=None, force_write=False, timeout=60):
    """Download ``url`` to ``local_fname`` and return the local path.

    Args:
        url: HTTP(S) URL of the file to fetch.
        local_fname: destination path. Defaults to the last path component
            of ``url``, relative to the current working directory.
        force_write: when False (default), an existing ``local_fname`` is
            returned as-is without contacting the server.
        timeout: seconds passed to ``requests.get`` as its connect/read
            timeout. ``None`` waits indefinitely.

    Returns:
        ``local_fname``.

    Raises:
        IOError: if the server does not answer with HTTP 200.
        Any ``requests`` exception (connection error, timeout, ...) propagates.
    """
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname
    # requests is not installed by default, so import it only when a
    # download is actually needed (cached files work without it).
    import requests
    dir_name = os.path.dirname(local_fname)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    # Stream into a temporary file and rename it only on success. Otherwise
    # an interrupted download would leave a truncated file that the
    # existence check above would later accept as complete.
    tmp_fname = local_fname + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            # An explicit check, not an assert: asserts are stripped under -O,
            # which would silently save an HTTP error page as the model file.
            if r.status_code != 200:
                raise IOError("failed to open %s (HTTP %d)" % (url, r.status_code))
            with open(tmp_fname, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        os.replace(tmp_fname, local_fname)
    finally:
        # After a successful os.replace the temp file no longer exists; on
        # any failure, remove the partial download.
        if os.path.exists(tmp_fname):
            os.remove(tmp_fname)
    return local_fname
def download_model(model_name, dst_dir='./', meta_info=None):
    """Download the symbol (.json) and parameter (.params) files of a model.

    Args:
        model_name: key into ``meta_info``, e.g. ``'imagenet1k-vgg-16'``.
        dst_dir: directory to store the files in. It is created, including
            any missing parent directories, if it does not exist.
        meta_info: mapping (or iterable of pairs) from model name to
            ``{'symbol': url, 'params': url}``. Defaults to
            ``_default_model_info``.

    Returns:
        ``(prefix, epoch)``. ``prefix`` is ``os.path.join(dst_dir, model_name)``,
        and ``epoch`` is always 0, matching the ``-0000.params`` file written
        here. Returns ``(None, 0)`` if ``model_name`` is not in ``meta_info``.

    Raises:
        ValueError: if the entry for ``model_name`` lacks a ``symbol`` or
            ``params`` URL.
    """
    if meta_info is None:
        meta_info = _default_model_info
    meta_info = dict(meta_info)
    if model_name not in meta_info:
        return (None, 0)
    meta = dict(meta_info[model_name])
    # Validate both URLs before touching the filesystem or the network, so a
    # bad entry does not leave a half-downloaded model behind. An explicit
    # raise is used because asserts are stripped under -O.
    missing = [key for key in ('symbol', 'params') if key not in meta]
    if missing:
        raise ValueError("missing %s url(s) for model '%s'"
                         % (', '.join(missing), model_name))
    # makedirs, unlike mkdir, also creates missing parent directories.
    os.makedirs(dst_dir, exist_ok=True)
    prefix = os.path.join(dst_dir, model_name)
    download_file(meta['symbol'], prefix + '-symbol.json')
    download_file(meta['params'], prefix + '-0000.params')
    return (prefix, 0)
# Fetch the pretrained VGG-16 symbol and weights into E:/Spyder/model.
# download_model returns the file prefix and the epoch (always 0).
dir_path = 'E:/Spyder/'
prefix, epoch = download_model(
    model_name='imagenet1k-vgg-16',
    dst_dir=os.path.join(dir_path, 'model'),
    meta_info=None,
)
（注：上面的链接已经失效，但仍可以用这段代码下载模型。）
还可以通过官方提供的 mx.gluon.model_zoo.vision 下载模型和参数。下面以 VGG16-BN（vgg16_bn）为例。
加载模型：其中 pretrained=True 表示同时下载对应的预训练参数，root 为参数文件的存储路径。
# Build VGG16-BN from the Gluon model zoo on the CPU. pretrained=True also
# downloads the weights into the directory given by root.
vgg16_bn = mx.gluon.model_zoo.vision.vgg16_bn(
    pretrained=True,
    ctx=mx.cpu(),
    root='E:\\Spyder\\model',
)
此时，参数文件已下载到 root 指定的目录中（Gluon 保存的文件名带有哈希后缀，形如 vgg16_bn-xxxxxxxx.params）。
如果本地已经有 params 文件，可以设置 pretrained=False，然后通过 load_parameters 加载参数（路径需与实际文件名一致）：
# Load weights from a local .params file into the network built above.
# NOTE(review): files fetched by the Gluon model zoo usually carry a hash
# suffix (vgg16_bn-<hash>.params); point this at the actual file name.
params_file = 'E:\\Spyder\\model\\vgg16_bn.params'
vgg16_bn.load_parameters(params_file)
下面将 vgg16_bn 导出为 json 文件。首先需要调用 hybridize() 将网络切换为混合（符号式）执行模式，之后才能导出 json 格式的模型文件：
# Turn on hybrid (symbolic) execution; export() requires a hybridized block.
vgg16_bn.hybridize(active=True)
然后，在导出模型文件和参数文件之前，需要先进行一次前向传播，以便构建并缓存计算图：
# GluonCV was used here without being imported; it provides the ImageNet
# evaluation preprocessing preset.
import gluoncv

img = mx.image.imread('./test_img/1.jpg')
# Per the GluonCV docs, transform_eval resizes, center-crops and normalises
# the image and returns a batch ready for the network.
transformed_img = gluoncv.data.transforms.presets.imagenet.transform_eval(img, resize_short=224)
# The network built above is named vgg16_bn; `net` was never defined.
output = vgg16_bn(transformed_img)
其实并不需要真实图像，使用形状正确的全 0 或全 1 张量也可以：
from mxnet import nd

# Real pixels are not needed to trace the graph, only the input shape:
# a batch of one all-zero 3x224x224 image.
dummy_shape = (1, 3, 224, 224)
data = nd.zeros(dummy_shape)
output = vgg16_bn(data)
最后，使用 export 在指定路径导出 json 和 params 文件（本例中分别为 vgg16_bn-symbol.json 和 vgg16_bn-0000.params）：
# Write the hybridized network as <prefix>-symbol.json plus
# <prefix>-0000.params (epoch 0).
export_prefix = 'E:\\Spyder\\model\\vgg16_bn'
vgg16_bn.export(export_prefix, epoch=0)