1. Note that the LayoutXLM RE model's inputs include a `relation` argument that is a dict; change it into several separate list/tensor inputs (a flattening sketch follows).
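A minimal sketch of that flattening, assuming the per-sample `entities`/`relations` dicts follow the original LayoutXLM RE code (keys `start`/`label` and `head`/`tail`); the helper name and exact keys are illustrative:

    import torch

    def relation_dict_to_tensors(entities, relations):
        # Hypothetical helper: turns the per-sample entities/relations dicts into the
        # plain int64 tensors ('start', 'labels', 'head', 'tail') used by the export
        # script in item 4, each with a leading batch dimension of 1.
        start = torch.tensor([entities['start']], dtype=torch.int64)   # (1, num_entities)
        labels = torch.tensor([entities['label']], dtype=torch.int64)  # (1, num_entities)
        head = torch.tensor([relations['head']], dtype=torch.int64)    # (1, num_relations)
        tail = torch.tensor([relations['tail']], dtype=torch.int64)    # (1, num_relations)
        return start, labels, head, tail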
2. The RE model contains a bilinear layer that the ONNX exporter does not support; this layer has to be rewritten.
3. Rewritten bilinear forward pass:
def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
    # ONNX-friendly replacement for nn.Bilinear.forward:
    # y[n, k] = sum_{i, j} input1[n, i] * weight[k, i, j] * input2[n, j] + bias[k]
    y = torch.zeros((input1.shape[0], self.weight.shape[0]),
                    dtype=input1.dtype, device=input1.device)
    for k in range(self.weight.shape[0]):
        buff = torch.matmul(input1, self.weight[k])  # (batch, in2_features)
        buff = buff * input2                         # element-wise instead of a second matmul
        buff = torch.sum(buff, dim=1)                # (batch,)
        y[:, k] = buff
    if self.bias is not None:
        y = y + self.bias
    return y
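One way to plug the rewrite in (a sketch, not necessarily how LayoutXLMForRelationExtractionExport does it): subclass nn.Bilinear, override forward, and swap the module before export. The einsum formulation below is mathematically equivalent to the loop above and, from opset 12 on, also traces to ONNX without the unsupported bilinear op:

    import torch
    from torch import nn, Tensor

    class BilinearExport(nn.Bilinear):
        # Drop-in replacement for nn.Bilinear whose forward avoids the op the
        # ONNX exporter rejects.
        def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
            # y[n, k] = sum_{i, j} input1[n, i] * weight[k, i, j] * input2[n, j] + bias[k]
            y = torch.einsum('bi,kij,bj->bk', input1, self.weight, input2)
            if self.bias is not None:
                y = y + self.bias
            return y

    # Swapping the module (the attribute path is illustrative; adapt it to where
    # the bilinear layer actually lives in the model):
    # old = model.extractor.rel_classifier.bilinear
    # new = BilinearExport(old.in1_features, old.in2_features, old.out_features,
    #                      bias=old.bias is not None)
    # new.load_state_dict(old.state_dict())
    # model.extractor.rel_classifier.bilinear = new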
4. ONNX export code:
import torch
from transformers import AutoTokenizer
from models.layoutxlm.modeling_layoutxlm import LayoutXLMForRelationExtractionExport
import numpy as np

model = LayoutXLMForRelationExtractionExport.from_pretrained('model_path')
tokenizer = AutoTokenizer.from_pretrained('model_path')

# Dummy inputs: fixed sequence length 512, batch size 1.
dummy_model_input = tokenizer("This is a sample", return_tensors=None)
dummy_model_input['input_ids'] = torch.zeros((1, 512), dtype=torch.int64) + 10
dummy_model_input['bbox'] = torch.zeros((1, 512, 4), dtype=torch.int64) + 20
dummy_model_input['attention_mask'] = torch.zeros((1, 512), dtype=torch.int64) + 1
# Entity start offsets, entity labels and relation head/tail indices replace the
# original entities/relations dicts (see item 1).
dummy_model_input['start'] = torch.tensor(np.asarray([[0, 3, 5, 7, 9, -1]], dtype=np.int64).reshape((1, -1)))
# dummy_model_input['end'] = torch.tensor(np.asarray([3, 5, 7, 9, 15], dtype=np.int64).reshape((1, -1)))
dummy_model_input['labels'] = torch.tensor(np.asarray([[1, 2, 1, 2, 1]], dtype=np.int64).reshape((1, -1)))
dummy_model_input['head'] = torch.tensor(np.asarray([[0, 2, 4]], dtype=np.int64).reshape((1, -1)))
dummy_model_input['tail'] = torch.tensor(np.asarray([[1, 3, 3]], dtype=np.int64).reshape((1, -1)))

# export
model.eval()
torch.onnx.export(
    model,
    (dummy_model_input['input_ids'],
     dummy_model_input['bbox'],
     dummy_model_input['attention_mask'],
     dummy_model_input['start'],
     # dummy_model_input['end'],
     dummy_model_input['labels'],
     dummy_model_input['head'],
     dummy_model_input['tail']),
    f="./layoutxlm_re.onnx",
    input_names=['input_ids', 'bbox', 'attention_mask', 'start', 'labels', 'head', 'tail'],
    output_names=['logits'],
    dynamic_axes={'input_ids': {0: 'batch_size'},
                  'bbox': {0: 'batch_size'},
                  'attention_mask': {0: 'batch_size'},
                  'start': {0: 'batch_size', 1: 'entity_len'},
                  # 'end': {0: 'batch_size', 1: 'entity_len'},
                  'labels': {0: 'batch_size', 1: 'entity_len'},
                  'head': {0: 'batch_size', 1: 'relation_len'},
                  'tail': {0: 'batch_size', 1: 'relation_len'},
                  'logits': {0: 'batch_size', 1: 'relation_len'}},
    # do_constant_folding=True,
    # verbose=True,
    opset_version=13,
)
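A quick sanity check of the exported graph with onnxruntime (a sketch; the feed simply mirrors the dummy tensors above, and onnxruntime is assumed to be installed):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("./layoutxlm_re.onnx", providers=["CPUExecutionProvider"])
    feed = {
        'input_ids': np.zeros((1, 512), dtype=np.int64) + 10,
        'bbox': np.zeros((1, 512, 4), dtype=np.int64) + 20,
        'attention_mask': np.ones((1, 512), dtype=np.int64),
        'start': np.asarray([[0, 3, 5, 7, 9, -1]], dtype=np.int64),
        'labels': np.asarray([[1, 2, 1, 2, 1]], dtype=np.int64),
        'head': np.asarray([[0, 2, 4]], dtype=np.int64),
        'tail': np.asarray([[1, 3, 3]], dtype=np.int64),
    }
    logits = sess.run(['logits'], feed)[0]
    print(logits.shape)  # first two axes follow the 'logits' dynamic_axes: batch_size, relation_len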
5. After conversion, because the model code contains a Python for loop (which is unrolled during tracing), the exported ONNX graph can only run inference with batch=1; see the batching workaround below.
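A simple workaround, assuming the onnxruntime session from the previous sketch: split the batch into single samples and loop over them in Python.

    import numpy as np

    def run_batched(sess, batch_feed):
        # batch_feed: dict of numpy arrays whose first axis is the batch dimension.
        # The graph was traced with batch=1, so each sample is run separately.
        outputs = []
        batch_size = next(iter(batch_feed.values())).shape[0]
        for b in range(batch_size):
            single = {name: arr[b:b + 1] for name, arr in batch_feed.items()}
            outputs.append(sess.run(['logits'], single)[0])
        # Concatenating assumes every sample yields the same relation_len;
        # otherwise return the list of per-sample logits instead.
        return np.concatenate(outputs, axis=0)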