MXNet转onnx问题点记录

MXNet转onnx问题点记录

  • MXnet转onnx时碰到的问题记录
    • 主要问题
    • 总结

MXnet转onnx时碰到的问题记录

最近将mxnet转onnx时碰到很多问题,在查找过程中发现解决方法,特记录下;转换过程主要参考如下链接中的解决方法
mxnet模型转onnx模型 作者liguiyuan112
Insightface中ArcFace MxNet2ONNX踩坑
MXNet Symbol Batch Normalization fix_gamma=True转ONNX方法

主要问题

  1. 转换时BN层出现错误
onnx.onnx_cpp2py_export.checker.ValidationError: Unrecognized attribute: spatial for operator BatchNormalization

==> Context: Bad node spec: input: "conv_1_conv2d" input: "conv_1_batchnorm_gamma" input: "conv_1_batchnorm_beta" input:"conv_1_batchnorm_moving_mean" input:"conv_1_batchnorm_moving_var" output: "conv_1_batchnorm" name: "conv_1_batchnorm" op_type: "BatchNormalization" attribute { name: "epsilon" f: 0.001 type: FLOAT } attribute { name: "momentum" f: 0.9  type: FLOAT } attribute { name: "spatial" i: 0 type: INT }

解决方法:参考mxnet模型转onnx模型 作者liguiyuan112 ,修改pip安装的mxnet文件夹下mxnet/contrib/onnx/mx2onnx/_op_translations.py中 359行左右的 spatial=0 注释掉即可

  1. 推理报错
    2.1. TypeError: 127.5 has type numpy.float32, but expected one of: int, long, float
    2.2. [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T)bound to different types (tensor (double) and tensor(float) in node (conv_1_conv2d).
    2.3. [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running PRelu node. Name:‘conv_1_relu’… Attempting to broadcast an axis by a dimension other than 1. 56 by 64
    参考Insightface中ArcFace MxNet2ONNX踩坑,其代码主要在github: zheshipinyinMc/arcface_retinaface_mxnet2onnx

  2. 推理结果不一致
    上述1和2步骤完成后,可以得到模型,但实际推理过程中发现onnxruntime推理结果和mxnet推理结果不一致,检查每层结果,发现在batchnorm层后两者之间的输出结果不一致,查询后,发现是mxnet中当batchnorm层中的参数fix_gamma为True时,其前向推理时 gamma参数不参与计算,所以需要将gamma参数的值置为1,可参考如下代码,注意修改 这一行代码,if ‘batchnorm_gamma’ in k or ‘fc1_gamma’ in k: 将所有fix_gamma为True的修改为1,fix_gamma为False时不需要修改。调用完成后,再利用2中的代码完成模型的转换,其推理结果就一致了。

def get_model(ctx, image_size, model_str, layer):
   _vec = model_str.split(',')
   assert len(_vec) == 2
   prefix = _vec[0]
   epoch = int(_vec[1])
   print('loading', prefix, epoch)
   sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
   reshape_params = {}
   for k, v in arg_params.items():
       if 'batchnorm_gamma' in k or 'fc1_gamma' in k:            
           v[:]=1.0
       reshape_params[k] = v
   mx.model.save_checkpoint(prefix, epoch, sym, reshape_params, aux_params)

总结

本文记录作者在转换insightface中gender_age模型时遇到的问题,如有侵权,请联系删除。

你可能感兴趣的:(mxnet,深度学习)