运行 python export_pt_to_onnx.py 导出模型时报错,其中的导出代码如下:
import os

text_encoder = get_text_encoder()
text = '测试一下结果'
# Text encoder: tokenize a single prompt, so the example batch size is 1.
x = tokenizer([text])
print('text_tokens:', x)
# Example of the printed structure. NOTE(review): these ids appear to come
# from a different (English) prompt, not from `text`; actual values differ.
# text_tokens: {'input_ids': tensor([[ 101, 1037, 3899, 2003, 2006, 1996, 5568, 102]]),
# 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]),
# 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

# ONNX graph input names. torch.onnx.export passes the `args` tuple to
# text_encoder.forward positionally, so this order must match forward's
# parameters (input_ids, token_type_ids, attention_mask).
input_names = ['input_ids', 'token_type_ids', 'attention_mask']
output_names = ['output']
# Select the example tensors by name so their order is guaranteed to match
# input_names, instead of depending on the tokenizer's dict ordering
# (a missing key now fails loudly with KeyError).
example_inputs = tuple(x[name] for name in input_names)
print(example_inputs)

text_encoder.eval()
opset_version = 15
with torch.no_grad():
    # Dynamic dimensions: batch size (axis 0) and sequence length (axis 1).
    # Named axes make every input share the same symbolic 'batch'/'sequence'.
    dynamic_axes = {name: {0: 'batch', 1: 'sequence'} for name in input_names}
    # Without an entry here the exporter records the traced output's concrete
    # shape. NOTE(review): assumes axis 0 of the output is the batch axis; if
    # forward returns per-token states, also mark axis 1 as 'sequence'.
    dynamic_axes['output'] = {0: 'batch'}
    # torch.onnx.export does not create missing parent directories.
    os.makedirs('onnx', exist_ok=True)
    torch.onnx.export(
        text_encoder,
        example_inputs,
        'onnx/text_encoder.onnx',
        input_names=input_names,
        output_names=output_names,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes,
    )
错误详细提示如下:
Traceback (most recent call last):
  File "/workspace/xx/export_pt_to_onnx.py", line 60, in <module>
    export()
  File "/workspace/xx/export_pt_to_onnx.py", line 36, in export
    torch.onnx.export(text_encoder,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 719, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 499, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 440, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 391, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given
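字面意思是:forward() 只接受 2 个位置参数,但实际传入了 4 个。这里的计数包含 self:torch.onnx.export 会把 args 元组 tuple(x.values()) 中的 3 个张量按位置传给 forward,加上 self 共 4 个;而 forward 只接受 self 加 1 个参数,共 2 个。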
查看自己定义的模型类的源代码:
def forward(self, x):
    # Original ("before") signature: takes a single mapping `x` of input
    # names to tensors (presumably the tokenizer output) and unpacks it into
    # self.base by keyword.
    # torch.onnx.export calls forward(*args), i.e. positionally. Exporting
    # with args=tuple(x.values()) therefore passes 3 tensors plus self,
    # while this signature accepts only self and x, which raises
    # "forward() takes 2 positional arguments but 4 were given".
    out = self.base(**x).last_hidden_state
    ...
从上面可以看出 forward 的输入参数只有 1 个 x,而导出时传入的是 3 个张量。因此把 forward 改为分别接收这 3 个输入:
def forward(self, input_ids, token_type_ids, attention_mask):
    """Encode token ids with the wrapped base model.

    Takes the three tokenizer tensors as separate positional parameters so
    that ``torch.onnx.export(model, (input_ids, token_type_ids,
    attention_mask), ...)`` can call it; each parameter maps to one ONNX
    graph input, in this order.
    """
    # Pass by keyword, as the original ``self.base(**x)`` did. If self.base
    # is a Hugging Face BERT-style model (it returns .last_hidden_state),
    # its forward() order is (input_ids, attention_mask, token_type_ids, ...).
    # A positional call in this function's order would silently feed
    # token_type_ids in as the attention mask, and vice versa.
    out = self.base(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    ).last_hidden_state
    ...
这样 forward 就从接收 1 个参数变成了接收 3 个参数,与导出时按顺序传入的 3 个张量 (input_ids, token_type_ids, attention_mask) 一一对应。
再次运行,导出成功!