使用 onnx.helper 修改 ONNX 模型:将多层融合为自定义节点,并以 Constant 节点形式加载权重参数

参考示例

在进行算子融合时,我们通常需要将模型中的多个层合并为一个自定义节点,以提高推理效率、缩短推理时间。下面介绍如何使用 onnx.helper 将多个层融合为一个自定义节点,并将权重参数以 Constant 节点的形式加载到 ONNX 模型中,具体步骤如下:

  1. 导入需要的模块
import onnx
import onnx.helper as helper
import numpy as np
  2. 加载原始 ONNX 模型
# Path of the ONNX model that is going to be edited.
model_path = 'original_model.onnx'
# onnx.load is an alias of onnx.load_model.
model = onnx.load_model(model_path)
  3. 定义新的自定义节点
# Node attributes are given to make_node as extra keyword arguments; each one
# becomes an AttributeProto. Writing `attribute=value_list` would instead
# create a single attribute literally named "attribute".
custom_attrs = {}  # fill with the plugin's attributes, e.g. {"num_heads": 16}
new_node = helper.make_node(
    'custom_node',  # op_type of the fused operator
    # The original input plus the outputs of the Constant nodes built below,
    # so the fused weights are actually fed into the custom node.
    inputs=['input_tensor', 'const_weight2', 'const_bias2', 'const_weight3', 'const_bias3'],
    outputs=['output_tensor'],
    name='custom_node_0',
    domain='custom',
    **custom_attrs,
)
# A node in a non-default domain needs a matching opset import, otherwise
# onnx.checker rejects the model.
if not any(opset.domain == 'custom' for opset in model.opset_import):
    model.opset_import.append(helper.make_opsetid('custom', 1))

其中,'custom_node' 为自定义节点的算子类型(op_type),'input_tensor' 为输入张量名称,'output_tensor' 为输出张量名称,value_list 表示节点属性。注意:make_node 会把额外的关键字参数逐个当作节点属性,因此属性应以 **attrs 的形式传入,而不能写成 attribute=value_list。

  4. 找到需要融合的层,并提取权重参数
# Assume the layers to fuse are graph nodes 2 and 3.
layer2 = model.graph.node[2]
layer3 = model.graph.node[3]

# Extract the weights. In an exported model, trained weights are normally
# stored in graph.initializer and referenced by name from node.input; they are
# not node attributes, so `layer.attribute[i].t` does not hold them.
# NOTE(review): assumes both layers take (X, W, B) inputs, as Conv and Gemm
# do. Adjust the input indices for other op types, or read Constant nodes if
# the weights were exported that way.
initializers = {init.name: init for init in model.graph.initializer}
weight2 = onnx.numpy_helper.to_array(initializers[layer2.input[1]])
bias2 = onnx.numpy_helper.to_array(initializers[layer2.input[2]])
weight3 = onnx.numpy_helper.to_array(initializers[layer3.input[1]])
bias3 = onnx.numpy_helper.to_array(initializers[layer3.input[2]])
  5. 创建新的常量节点,并将权重参数赋值给它
def _make_const_node(output_name, array):
    """Return a Constant node whose `value` tensor holds `array`.

    numpy_helper.from_array keeps the array's shape and dtype and stores the
    data as raw bytes. This avoids the slow per-element `.tolist()` copy and a
    hard-coded FLOAT type, which would mislabel e.g. float16 weights.
    """
    return helper.make_node(
        'Constant',
        inputs=[],
        outputs=[output_name],
        value=onnx.numpy_helper.from_array(np.asarray(array), name=output_name),
    )


# Constant nodes that carry the extracted weights. Their output names can be
# used as input names of the fused custom node.
const_node2 = _make_const_node('const_weight2', weight2)
const_node_bias2 = _make_const_node('const_bias2', bias2)
const_node3 = _make_const_node('const_weight3', weight3)
const_node_bias3 = _make_const_node('const_bias3', bias3)

其中,'Constant' 为常量节点的算子类型,'const_weight2'、'const_bias2'、'const_weight3'、'const_bias3' 为常量节点的输出名称(自定义节点通过这些名称引用权重),weight2、bias2、weight3、bias3 为对应的权重参数。

  6. 将新的自定义节点、常量节点与原始模型的其余部分合并
# Replace the fused layers in place. layer2/layer3 were read from
# graph.node[2] and graph.node[3]; delete them first (higher index first) so
# those indices still point at them. Deleting after inserting new nodes at the
# front would hit shifted indices and remove Constant nodes instead.
fuse_pos = 2
del model.graph.node[fuse_pos + 1]  # layer3
del model.graph.node[fuse_pos]  # layer2

# Insert the constants and the custom node where layer2 used to be. ONNX
# requires graph.node to be topologically sorted, and appending the custom
# node at the very end could place it after its consumers.
replacement_nodes = [const_node2, const_node_bias2, const_node3, const_node_bias3, new_node]
for offset, replacement in enumerate(replacement_nodes):
    model.graph.node.insert(fuse_pos + offset, replacement)
# NOTE(review): nodes that consumed layer3's output must be rewired to read
# the custom node's output instead — TODO confirm for your graph.

最终,我们得到了一个新的 ONNX 模型:其中多个层已被融合为一个自定义节点,权重参数以常量节点的形式加载进来。注意,自定义节点需要推理后端提供对应的实现(例如 TensorRT 插件)才能运行;在此基础上,可以再使用 ONNX 模型优化器或 TensorRT 等工具对模型进行优化和加速。

transformer实例

使用 onnx.helper 修改 ONNX 模型:将多层融合替换为一个自定义节点,并将这些层的权重参数以 Constant 节点的形式添加到 ONNX 中。
下面以 Transformer 为例,将其融合为一个节点,并把权重参数作为该节点的输入添加进去,生成新的 ONNX 模型。完整示例代码:

import onnx
import torch
import torch.nn as nn
import onnxruntime as ort
import numpy as np
from onnx import helper, shape_inference, AttributeProto, TensorProto, GraphProto
from onnx import numpy_helper
import onnx_graphsurgeon as gs
from onnx_export import LaneViT


# Load the training checkpoint and build `vit_weights`: a copy of its
# 'model_state_dict' without entries whose key contains 'cbam', and with the
# 'projector' keys renamed.
# NOTE(review): vit_weights is not used anywhere else in the code shown here.
ckpt_path = 'best_model_0530.pth'
ckpt = torch.load(ckpt_path)
vit_weights = {}
for name, tensor in ckpt['model_state_dict'].items():
    if 'cbam' in name:
        continue
    if 'projector' not in name:
        vit_weights[name] = tensor
    elif "projector.resnet.conv.weight" in name:
        # This conv weight is stored under a different key.
        vit_weights['projector.resnet.out.weight'] = tensor
    else:
        vit_weights[name.replace('projector.resnet.resnet', 'projector.resnet.model')] = tensor

def _encoder_layer_entry(k, v, hidden_size, num_heads):
    """Convert one encoder-layer checkpoint entry to ``(plugin_name, tensor)``.

    Returns None for attention entries that are neither weights nor biases
    (they are skipped). Raises ValueError for an unknown sub-layer name.
    """
    k_list = k.split('.')
    # NOTE(review): assumes keys look like 'a.b.c.layer.<i>.<sub-layer...>',
    # i.e. field 4 is the layer index, as in the checkpoint this script targets.
    layer_num = int(k_list[4])
    layer_name = '.'.join(k_list[5:])
    prefix = 'Transformer/encoderblock_' + str(layer_num)

    # (substring of the sub-layer name, plugin suffix, transpose the weight?)
    plain_entries = (
        ('attention_norm.weight', '/LayerNorm_0/scale', False),
        ('attention_norm.bias', '/LayerNorm_0/bias', False),
        ('ffn_norm.weight', '/LayerNorm_2/scale', False),
        ('ffn_norm.bias', '/LayerNorm_2/bias', False),
        ('fc1.weight', '/MlpBlock_3/Dense_0/kernel', True),
        ('fc1.bias', '/MlpBlock_3/Dense_0/bias', False),
        ('fc2.weight', '/MlpBlock_3/Dense_1/kernel', True),
        ('fc2.bias', '/MlpBlock_3/Dense_1/bias', False),
    )
    for src, suffix, transpose in plain_entries:
        if src in layer_name:
            return prefix + suffix, (v.t() if transpose else v).contiguous()

    if 'attn' not in layer_name:
        raise ValueError('[ERROR] unrecognised encoder-layer weight: {}'.format(k))

    head_dim = hidden_size // num_heads
    sub_attn_name = layer_name.split('.')[1]
    base = prefix + '/MultiHeadDotProductAttention_1/' + sub_attn_name
    is_out = 'out' in layer_name
    if 'weight' in layer_name:
        if is_out:
            # Output projection: view as (hidden, heads, head_dim), then
            # permute to (heads, head_dim, hidden).
            kernel = v.view(hidden_size, num_heads, head_dim).permute(1, 2, 0)
        else:
            # Query/key/value: transpose, then split the last dim into heads.
            kernel = v.t().contiguous().view(hidden_size, num_heads, head_dim)
        return base + '/kernel', kernel.contiguous()
    if 'bias' in layer_name:
        bias = v if is_out else v.view(num_heads, head_dim)
        return base + '/bias', bias.contiguous()
    return None


def load_torch_weights(ckpt_path, hidden_size=512, num_heads=16, map_location=None):
    """Map a PyTorch checkpoint onto the weight names the Transformer plugin expects.

    Only keys of ``ckpt['model_state_dict']`` containing ``'backbone'`` are
    converted. Backbone keys that match no known pattern are reported with a
    printed message and skipped.

    Args:
        ckpt_path: checkpoint saved with ``torch.save`` holding a
            ``'model_state_dict'`` entry.
        hidden_size: embedding width used to reshape attention weights
            (default 512, the value this script was written for).
        num_heads: number of attention heads; the per-head size is
            ``hidden_size // num_heads`` (default 16).
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a machine without CUDA.

    Returns:
        dict mapping plugin weight names to contiguous tensors.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``, or an
            encoder-layer key matches none of the known sub-layers (this used
            to call ``exit()`` and terminate the whole process).
    """
    if num_heads <= 0 or hidden_size % num_heads:
        raise ValueError(
            'hidden_size ({}) must be divisible by num_heads ({})'.format(hidden_size, num_heads))

    # Backbone weights outside the encoder layers, matched in this order.
    global_names = (
        ('position_embeddings', 'Transformer/posembed_input/pos_embedding'),
        ('patch_embeddings.weight', 'embedding/kernel'),
        ('patch_embeddings.bias', 'embedding/bias'),
        ('encoder_norm.weight', 'Transformer/encoder_norm/scale'),
        ('encoder_norm.bias', 'Transformer/encoder_norm/bias'),
    )

    ckpt = torch.load(ckpt_path, map_location=map_location)
    plugin_key_weight = {}
    for k, v in ckpt['model_state_dict'].items():
        if 'backbone' not in k:
            continue
        target = next((dst for src, dst in global_names if src in k), None)
        if target is not None:
            plugin_key_weight[target] = v.contiguous()
        elif 'layer' in k:
            entry = _encoder_layer_entry(k, v, hidden_size, num_heads)
            if entry is not None:
                plugin_key_weight[entry[0]] = entry[1]
        else:
            print('[ERROR] layer_name wrong ? The name is {}'.format(k))
    return plugin_key_weight

# Convert the PyTorch checkpoint into the plugin's weight naming scheme.
plugin_key_weight = load_torch_weights(ckpt_path)

# ONNX model that contains the node to be replaced by the plugin.
model_file = 'custom_model_simp_0530.onnx'
model = onnx.load_model(model_file)  # onnx.load is an alias of load_model
graph = model.graph

def get_node(onnx_model, node_name):
    """Return ``(node, index)`` for the first node whose name contains `node_name`.

    Side effect: every node visited during the scan (up to and including the
    match, or every node when nothing matches) has its ``doc_string`` set to
    the literal string ``"None"``.

    When no node matches, ``(None, 0)`` is returned. Because 0 is also a valid
    index, callers must check the node for None rather than the index.
    """
    for position, candidate in enumerate(onnx_model.graph.node):
        candidate.doc_string = "None"
        if node_name not in candidate.name:
            continue
        return candidate, position
    return None, 0

def create_custom_node(node, out_idx, weights, attrs=None):
    """Build the Constant nodes and the plugin node that replace `node`.

    Every entry of `weights` becomes a Constant node whose output (and node
    name) is the weight name. These names are appended, in dict order, after
    ``node.input[0]`` as inputs of the plugin node. The plugin node takes over
    ``node.output[out_idx]``; the original node's output is renamed with a
    ``"_plugin_out"`` suffix so the name is not defined twice.

    Args:
        node: NodeProto to replace. The plugin op_type is the part of
            ``node.name`` before the first ``'_'``.
        out_idx: index of the output of `node` that the plugin produces.
        weights: mapping name -> torch tensor or numpy array; each is stored
            as a FLOAT (float32) tensor.
        attrs: plugin attributes as a name -> value mapping, added in the
            mapping's order. Defaults to the ViT configuration this script
            was written for.

    Returns:
        List of new nodes: all Constant nodes first, the plugin node last.
    """
    if attrs is None:
        attrs = {
            "img_size": 144,
            "patch_size": 8,
            "in_chans": 64,
            "embed_dim": 512,
            "num_heads": 16,
            "inter_size": 2048,
            "layer_num": 3,
            "with_cls_token": 0,
        }

    node_input_name = node.input[0]
    node_out_name = node.output[out_idx]
    node.output[out_idx] = node_out_name + "_plugin_out"

    inputs_name = [node_input_name]
    new_nodes = []
    for count, (k, v) in enumerate(weights.items(), start=1):
        # Accept torch tensors as well as numpy arrays.
        array = v.detach().cpu().numpy() if hasattr(v, "detach") else v
        array = np.ascontiguousarray(array, dtype=np.float32)
        print(count, k, array.shape, array.reshape(-1).shape)
        # from_array stores the whole tensor as raw bytes in one copy instead
        # of pushing values one by one into float_data via make_tensor(vals=...).
        constant_node = helper.make_node(
            op_type="Constant",
            inputs=[],
            outputs=[k],
            name=k,
            value=numpy_helper.from_array(array, name=k),
        )
        inputs_name.append(k)
        new_nodes.append(constant_node)

    custom_node = helper.make_node(op_type=node.name.split('_')[0],
                                   inputs=inputs_name,
                                   outputs=[node_out_name],
                                   name=node.name,
                                   doc_string="custom plugin.",
                                   domain="custom")
    # Attributes are added explicitly (not as make_node kwargs) so they keep
    # the order of `attrs`; make_node would sort keyword attributes by name.
    custom_node.attribute.extend(
        helper.make_attribute(attr_name, attr_value) for attr_name, attr_value in attrs.items()
    )
    new_nodes.append(custom_node)
    return new_nodes

node, node_id = get_node(model, "TransformerPlugin")
# get_node returns (None, 0) when nothing matches; fail clearly instead of
# crashing inside create_custom_node.
if node is None:
    raise RuntimeError("no node whose name contains 'TransformerPlugin' was found")
print(node_id, node)
new_nodes = create_custom_node(node, 0, plugin_key_weight)
# Swap the original node for [constants..., plugin] at the same position so
# the graph stays topologically sorted.
model.graph.node.remove(node)
for new_node in reversed(new_nodes):
    model.graph.node.insert(node_id, new_node)

# The plugin node lives in the "custom" domain; the checker rejects nodes
# whose domain has no opset import, so declare one if it is missing.
if not any(opset.domain == "custom" for opset in model.opset_import):
    model.opset_import.append(helper.make_opsetid("custom", 1))

onnx.checker.check_model(model)
# infer_shapes returns a new model rather than modifying its argument.
model = shape_inference.infer_shapes(model)
onnx.save(model, "./model/fused_model_0530.onnx")


你可能感兴趣的:(python,人工智能,pytorch)