Exporting a custom BERT model to ONNX fails with: TypeError: forward() takes 2 positional arguments but 4 were given

Export code

Excerpt from export_pt_to_onnx.py (run with: python export_pt_to_onnx.py)

    import torch

    # get_text_encoder() and tokenizer are defined elsewhere in the script (not shown)
    text_encoder = get_text_encoder()
    text = '测试一下结果'  # "test the result"

    # text encoder
    x = tokenizer([text])  # batch_size = 1 by default
    print('text_tokens:', x)
    # text_tokens: {'input_ids': tensor([[ 101, 1037, 3899, 2003, 2006, 1996, 5568,  102]]),
    # 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]),
    # 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

    # ONNX input names are matched to the tensors in the args tuple by position
    input_names = ['input_ids', 'token_type_ids', 'attention_mask']
    output_names = ['output']
    text_encoder.eval()
    print(tuple(x.values()))

    opset_version = 15
    with torch.no_grad():
        dynamic_axes = {  # dynamic dimensions: 0 = batch size, 1 = sequence length
            'input_ids': [0, 1],
            'attention_mask': [0, 1],
            'token_type_ids': [0, 1],
        }
        # The args tuple is unpacked: each tensor becomes a separate positional argument of forward()
        torch.onnx.export(text_encoder,
                          tuple(x.values()),
                          'onnx/text_encoder.onnx',
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=opset_version,
                          dynamic_axes=dynamic_axes,
                          )
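A cheap pre-flight check before calling torch.onnx.export can catch this kind of mismatch with a readable message instead of a deep traceback. The sketch below uses only the standard library's inspect.signature; the helper name check_export_args and the two stand-in classes are hypothetical, and the order check assumes the convention used in this post, where input_names mirror forward()'s parameter names.

```python
import inspect


def check_export_args(forward, input_names):
    """Fail early if forward() cannot take the export inputs as positional args.

    Assumes the ONNX input_names mirror forward()'s parameter names, in order.
    """
    # For a bound method such as model.forward, inspect.signature omits self.
    params = [
        p.name
        for p in inspect.signature(forward).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    if len(params) != len(input_names):
        raise ValueError(
            f"forward() accepts {len(params)} input(s) {params}, "
            f"but {len(input_names)} were given: {list(input_names)}"
        )
    if params != list(input_names):
        raise ValueError(f"order mismatch: forward() expects {params}, "
                         f"input_names is {list(input_names)}")
    return params


# Hypothetical stand-ins for the model before and after the fix.
class EncoderBefore:
    def forward(self, x):
        return x


class EncoderAfter:
    def forward(self, input_ids, token_type_ids, attention_mask):
        return input_ids


names = ['input_ids', 'token_type_ids', 'attention_mask']
print(check_export_args(EncoderAfter().forward, names))
try:
    check_export_args(EncoderBefore().forward, names)
except ValueError as err:
    print(err)
```

In the export script this would be a single line, check_export_args(text_encoder.forward, input_names), placed just before torch.onnx.export.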

Error message

The full traceback:

Traceback (most recent call last):
  File "/workspace/xx/export_pt_to_onnx.py", line 60, in <module>
    export()
  File "/workspace/xx/export_pt_to_onnx.py", line 36, in export
    torch.onnx.export(text_encoder, 
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 719, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 499, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 440, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 391, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

Core error

  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

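Literally: forward() takes 2 positional arguments, but 4 were given. Python counts self in both numbers: the model's forward(self, x) accepts self plus one input, but torch.onnx.export unpacks the args tuple, so forward() receives self plus the three tensors input_ids, token_type_ids and attention_mask.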
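The error can be reproduced without PyTorch or BERT. In the sketch below, the hypothetical Encoder class mimics how nn.Module.__call__ hands its positional arguments straight to forward(), and the hypothetical call_like_tracer mimics the self.inner(*trace_inputs) call visible in the traceback. With three placeholders in the tuple, the message ends with "takes 2 positional arguments but 4 were given" (newer Python versions may prefix the class name); passing the dict as one argument works.

```python
class Encoder:
    """Stand-in for the custom model: forward() takes a single dict x."""

    def forward(self, x):
        return x

    def __call__(self, *args, **kwargs):
        # Like nn.Module.__call__, pass all positional args straight to forward().
        return self.forward(*args, **kwargs)


def call_like_tracer(model, args):
    """Unpack the args tuple the way the tracer does: model(*args)."""
    try:
        return model(*args)
    except TypeError as err:
        return f"TypeError: {err}"


# Placeholders standing in for the input_ids, token_type_ids, attention_mask tensors.
inputs = {"input_ids": 1, "token_type_ids": 2, "attention_mask": 3}

print(call_like_tracer(Encoder(), tuple(inputs.values())))  # three args: error
print(call_like_tracer(Encoder(), (inputs,)))                # one dict arg: fine
```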

Solution

Inspect the source

Look at the source of the custom model class:

    def forward(self, x):
        out = self.base(**x).last_hidden_state
        ...

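As shown above, forward() accepts only one input, x: a dict that is unpacked into the BERT backbone with **x. torch.onnx.export, however, does not pass the args tuple as a single object; it passes each of the three tensors as a separate positional argument.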
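If you would rather leave the model class untouched, a wrapper can adapt the calling convention instead. torch.onnx.export expects a module, so in practice this would be a small nn.Module (for example a hypothetical ExportWrapper holding text_encoder) whose forward(self, input_ids, token_type_ids, attention_mask) rebuilds the dict and calls the original model. The sketch below shows only that repacking step in plain Python, using a hypothetical helper make_positional_adapter, so it runs without PyTorch:

```python
def make_positional_adapter(forward_with_dict, names):
    """Let a forward(x: dict) be called as forward(t1, t2, ...), one arg per name."""
    names = list(names)

    def forward_positional(*tensors):
        if len(tensors) != len(names):
            raise TypeError(f"expected {len(names)} inputs {names}, got {len(tensors)}")
        # Rebuild the dict that the original forward(x) expects.
        return forward_with_dict(dict(zip(names, tensors)))

    return forward_positional


# Hypothetical stand-in for the original forward(self, x), which calls self.base(**x).
def original_forward(x):
    return x


adapted = make_positional_adapter(
    original_forward, ['input_ids', 'token_type_ids', 'attention_mask'])
print(adapted('ids', 'segments', 'mask'))
```

The trade-off: a wrapper keeps the training code unchanged, while editing forward() directly (as below) keeps the exported graph's entry point identical to the model you train.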

Modify the source

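    def forward(self, input_ids, token_type_ids, attention_mask):
        # Pass by keyword, as self.base(**x) did: the backbone's parameter order may differ
        out = self.base(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask).last_hidden_state
        ...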

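forward() now takes three inputs instead of one, matching the three tensors in the export args tuple and, in the same order, the names in input_names.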
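A related pitfall when spelling the arguments out: assuming self.base is a Hugging Face transformers BertModel (the .last_hidden_state attribute suggests so, though the post does not say), its forward() parameters begin input_ids, attention_mask, token_type_ids. A positional call such as self.base(input_ids, token_type_ids, attention_mask) would feed the segment ids in as the attention mask and vice versa, raising no error but silently changing the outputs. The hypothetical stand-in below mimics only that parameter order:

```python
def bert_like_forward(input_ids=None, attention_mask=None, token_type_ids=None):
    """Hypothetical stand-in with the first three parameters of BertModel.forward;
    reports which value landed in which parameter."""
    return {"input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids}


ids, segments, mask = "ids", "segments", "mask"

# Positional call in the export forward()'s order: arguments 2 and 3 swap roles.
positional = bert_like_forward(ids, segments, mask)
# Keyword call: every value reaches the intended parameter.
keyword = bert_like_forward(input_ids=ids, token_type_ids=segments, attention_mask=mask)

print(positional["attention_mask"])  # segments
print(keyword["attention_mask"])     # mask
```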

Run the script again and the export succeeds.
