Gluon是MXNet的动态图接口;Gluon学习了Keras,Chainer,和Pytorch的优点,并加以改进。接口更简单,且支持动态图(Imperative)编程。相比TF,Caffe2等静态图(Symbolic)框架更加灵活易用。同时Gluon还继承了MXNet速度快,省显存,并行效率高的优点,并支持静、动态图混用,比Pytorch更快。——转自解浚源知乎
题目中提及的.json、.params模型如下所示:
mxnet版本高于1.2.1可以使用如下方法:
net = gluon.nn.SymbolBlock.imports('resnet18-symbol,['data'],param_file='resnet18-0000.params',ctx=mx.gpu())
以前的版本可以使用:
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18', 0)
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
# Set the params
net_params = net.collect_params()
for param in arg_params:
if param in net_params:
net_params[param]._load_init(arg_params[param], ctx=ctx)
for param in aux_params:
if param in net_params:
net_params[param]._load_init(aux_params[param], ctx=ctx)
net_params是ParameterDict类型,也就是value为Parameter类型的字典,其可以通过data()函数获得其具体参数,参数类型为NDArray,如:
arraya = net_params['stage4_unit3_bn2_beta'].data()
arg_params,aux_params均是一个字典类型,他们的结构均为"参数名称":NDarray,如:
arrayb = arg_params['stage4_unit3_bn2_beta
需要说明的是:
inputs=mx.sym.var('data')
是使用静态图的方法生成一个输入节点名为'data' ,arg_params是主要参数如weights,aux_params是辅助参数主要是bias或者是batchnorm中的一些参数。
疑问:以上方法是对模型中的参数一个个的load,虽然已经装载进去了但是net_params内部的参数的shape扔然是None这是不解的地方,如下所示;
疑问5.29日解决,原因是:虽然模型函数已经加载了参数,但是mxnet模型推断机制是在模型进行一次前向计算(forward)后才完成,如下图所示:
SymbolBlock是继承于block有好多的Sequence的方法,其并不能使用,如net[0]因为其内部并没有__getitems__()函数所以这种访问模型内部参数字典的方法并不适用
第一种方法的import函数内容如下:
def imports(symbol_file, input_names, param_file=None, ctx=None):
"""Import model previously saved by `HybridBlock.export` or
`Module.save_checkpoint` as a SymbolBlock for use in Gluon.
Parameters
----------
symbol_file : str
Path to symbol file.
input_names : list of str
List of input variable names
param_file : str, optional
Path to parameter file.
ctx : Context, default None
The context to initialize SymbolBlock on.
Returns
-------
SymbolBlock
SymbolBlock loaded from symbol and parameter files.
Examples
--------
>>> net1 = gluon.model_zoo.vision.resnet18_v1(
... prefix='resnet', pretrained=True)
>>> net1.hybridize()
>>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
>>> out1 = net1(x)
>>> net1.export('net1', epoch=1)
>>>
>>> net2 = gluon.SymbolBlock.imports(
... 'net1-symbol.json', ['data'], 'net1-0001.params')
>>> out2 = net2(x)
"""
sym = symbol.load(symbol_file)
if isinstance(input_names, str):
input_names = [input_names]
inputs = [symbol.var(i) for i in input_names]
ret = SymbolBlock(sym, inputs)
if param_file is not None:
ret.collect_params().load(param_file, ctx=ctx)
return ret
特殊情况下可以用到