使用onnxruntime推理Bert模型

Bert模型类别:onnx
输入输出数据格式:.npz

import onnxruntime
import numpy as np
import os

# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession('bert-base-uncased_final.onnx')

# 指定输入文件夹和输出文件夹
input_folder = ''
output_folder = ''

# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)

# 遍历输入文件
input_files = os.listdir(input_folder)
for input_file in input_files:
    if input_file.endswith('.npz'):
        input_path = os.path.join(input_folder, input_file)
        output_path = os.path.join(output_folder, input_file)
        print('input path:', input_path)
        # 加载 npz 格式的输入数据
        input_data = np.load(input_path)
        input_ids = input_data['input_ids']
        attention_mask = input_data['attention_mask']
        token_type_ids = input_data['token_type_ids']
        
        # 执行推理
        input_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }
        outputs = ort_session.run(None, input_dict)
        
        # 获取推理结果
        output_start_logits = outputs[0]
        output_end_logits = outputs[1]
        
        # 保存推理结果为 npz 格式
        output_data = {
            'output_start_logits': output_start_logits,
            'output_end_logits': output_end_logits
        }
        np.savez(output_path, **output_data)
        
        print('output path:', output_path)

你可能感兴趣的:(bert,人工智能,深度学习)