在进行模型融合时,我们通常需要将模型中的多个层合并为一个自定义节点,以提高推理效率、减少推理时间。下面介绍如何用onnx.helper将模型中的多个层融合为一个自定义节点,并将权重参数以Constant节点的形式加载到onnx模型中,具体步骤如下:
import onnx
import onnx.helper as helper
import numpy as np
model_path = 'original_model.onnx'
model = onnx.load(model_path)

# Attributes of the fused node as AttributeProto objects, e.g.
# [helper.make_attribute('num_heads', 16)]; fill in what the plugin needs.
value_list = []

# A custom operator must live in its own (non-default) domain: the ONNX
# checker rejects op types it does not know in the default domain.
new_node = helper.make_node(
    'custom_node',             # op_type of the fused operator
    inputs=['input_tensor'],
    outputs=['output_tensor'],
    name='custom_node_0',      # node name (distinct from the op type)
    domain='custom',
)
# make_node turns extra keyword arguments into attributes named after the
# keyword, so ``attribute=value_list`` would create one attribute literally
# called "attribute"; attach the prebuilt AttributeProto objects directly.
new_node.attribute.extend(value_list)

# Nodes in a non-default domain need a matching opset import in the model.
if not any(opset.domain == 'custom' for opset in model.opset_import):
    model.opset_import.append(helper.make_opsetid('custom', 1))
其中,'custom_node'为自定义节点的算子类型(op_type,即make_node的第一个参数,并非节点名称),'input_tensor'为输入张量名称,'output_tensor'为输出张量名称,value_list为节点属性列表。
# Assume the layers to fuse are graph nodes 2 and 3.
layer2 = model.graph.node[2]
layer3 = model.graph.node[3]

# Initializers by name, built once for the parameter lookups below.
_initializers = {init.name: init for init in model.graph.initializer}


def _layer_param(layer, slot):
    """Return parameter ``slot`` (0 = weight, 1 = bias) of ``layer`` as a numpy array.

    Conv/Gemm-style ONNX ops do not carry their weights as attributes: the
    weight and bias are graph initializers referenced by name from
    ``layer.input[1]`` and ``layer.input[2]``. Those are used when present;
    otherwise this falls back to the tensor stored in ``layer.attribute[slot]``.
    """
    input_idx = slot + 1
    if input_idx < len(layer.input) and layer.input[input_idx] in _initializers:
        return np.array(onnx.numpy_helper.to_array(_initializers[layer.input[input_idx]]))
    return np.array(onnx.numpy_helper.to_array(layer.attribute[slot].t))


# Extract the weight parameters.
weight2 = _layer_param(layer2, 0)
bias2 = _layer_param(layer2, 1)
weight3 = _layer_param(layer3, 0)
bias3 = _layer_param(layer3, 1)
def _float_constant(out_name, array):
    """Wrap ``array`` in an unnamed Constant node whose single output is ``out_name``.

    The tensor is stored as FLOAT with the array's shape and carries the same
    name as the node output.
    """
    tensor = onnx.helper.make_tensor(
        name=out_name,
        data_type=onnx.TensorProto.FLOAT,
        dims=array.shape,
        vals=array.flatten().tolist(),
    )
    return helper.make_node('Constant', inputs=[], outputs=[out_name], value=tensor)


# Create one Constant node per weight/bias tensor.
const_node2 = _float_constant('const_weight2', weight2)
const_node_bias2 = _float_constant('const_bias2', bias2)
const_node3 = _float_constant('const_weight3', weight3)
const_node_bias3 = _float_constant('const_bias3', bias3)
其中,'Constant'为常量节点的算子类型,'const_weight2'、'const_bias2'、'const_weight3'、'const_bias3'为常量节点的输出(张量)名称,weight2、bias2、weight3、bias3为对应的权重参数。
# Remove the fused layers FIRST, while they still sit at indices 2 and 3.
# Inserting the four constants at the front beforehand would shift the
# original layers to 6 and 7, and deleting indices 3 and 2 would then drop
# two of the freshly inserted Constant nodes instead. Delete the higher index
# first so the lower one is unaffected.
del model.graph.node[3]
del model.graph.node[2]

# NOTE(review): the custom op presumably consumes the fused layers' weights;
# feed the Constant outputs in as extra inputs so they are not dead nodes.
# Adjust the order to whatever the plugin implementation expects.
new_node.input.extend(['const_weight2', 'const_bias2', 'const_weight3', 'const_bias3'])

# Put the fused node where layer 2 was, so it stays after its producers and
# before its consumers (ONNX requires a topologically sorted node list).
model.graph.node.insert(2, new_node)

# Constants have no inputs, so the front of the node list is always valid.
# Resulting order: const_node2, const_node_bias2, const_node3, const_node_bias3.
for const_node in (const_node_bias3, const_node3, const_node_bias2, const_node2):
    model.graph.node.insert(0, const_node)
最终,我们得到了一个新的onnx模型,其中多个层已经被融合成了一个自定义节点,权重参数也以常量节点的形式加载了进来。之后可以使用onnx模型优化器或TensorRT等工具对模型进一步优化和加速;注意,自定义算子需要在推理后端中有对应的实现(例如TensorRT插件或onnxruntime自定义算子),否则模型无法直接运行。
使用onnx.helper修改onnx模型:将多层融合替换为一个自定义节点,并将这些层的权重参数以Constant节点的形式添加到onnx模型中。
下面以transformer为例,将其融合为一个节点,并将其权重参数作为该节点的输入添加到模型中,生成新的onnx模型。完整示例代码:
import onnx
import torch
import torch.nn as nn
import onnxruntime as ort
import numpy as np
from onnx import helper, shape_inference, AttributeProto, TensorProto, GraphProto
from onnx import numpy_helper
import onnx_graphsurgeon as gs
from onnx_export import LaneViT
# Load the training checkpoint and build `vit_weights`: a copy of its state
# dict with CBAM entries dropped and projector keys renamed.
ckpt_path = 'best_model_0530.pth'
ckpt = torch.load(ckpt_path)
vit_weights = {}
for k, v in ckpt['model_state_dict'].items():
    if 'cbam' in k:
        # CBAM parameters are not carried over.
        continue
    if 'projector' not in k:
        vit_weights[k] = v
    elif "projector.resnet.conv.weight" in k:
        # The projector's conv weight is stored under the name '...out.weight'.
        vit_weights['projector.resnet.out.weight'] = v
    else:
        # Remaining projector keys: 'resnet.resnet' becomes 'resnet.model'.
        k = k.replace('projector.resnet.resnet', 'projector.resnet.model')
        vit_weights[k] = v
def load_torch_weights(ckpt_path, hidden_size=512, num_heads=16):
    """Load a checkpoint and rename its backbone tensors to the plugin's weight keys.

    Only keys containing 'backbone' are used; all other keys are ignored.
    Attention parameters are reshaped per head: query/key/value kernels become
    (hidden_size, num_heads, head_dim) (after a transpose), the output kernel
    becomes (num_heads, head_dim, hidden_size), non-output biases become
    (num_heads, head_dim) and the output bias is kept as-is.

    Args:
        ckpt_path: path to a ``torch.save``'d dict with a 'model_state_dict' entry.
        hidden_size: attention hidden size of the exported model.
        num_heads: number of attention heads; must divide ``hidden_size``.

    Returns:
        dict mapping plugin weight names (e.g.
        'Transformer/encoderblock_0/LayerNorm_0/scale') to torch tensors.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``, or an
            attention parameter is neither a weight nor a bias.
    """
    if hidden_size % num_heads:
        raise ValueError('hidden_size ({}) must be divisible by num_heads ({})'.format(
            hidden_size, num_heads))
    head_dim = hidden_size // num_heads  # 512 / 16 = 32 with the defaults

    ckpt = torch.load(ckpt_path)
    plugin_key_weight = {}
    for k, v in ckpt['model_state_dict'].items():
        if 'backbone' not in k:
            continue
        if 'position_embeddings' in k:
            plugin_key_weight['Transformer/posembed_input/pos_embedding'] = v.contiguous()
        elif 'patch_embeddings.weight' in k:
            plugin_key_weight['embedding/kernel'] = v.contiguous()
        elif 'patch_embeddings.bias' in k:
            plugin_key_weight['embedding/bias'] = v.contiguous()
        elif 'encoder_norm.weight' in k:
            plugin_key_weight['Transformer/encoder_norm/scale'] = v.contiguous()
        elif 'encoder_norm.bias' in k:
            plugin_key_weight['Transformer/encoder_norm/bias'] = v.contiguous()
        elif 'layer' in k:
            # Field 4 of the dotted key is the encoder block index; the rest
            # is the parameter path inside that block.
            k_list = k.split('.')
            layer_num = int(k_list[4])
            layer_name = '.'.join(k_list[5:])
            block = 'Transformer/encoderblock_' + str(layer_num)
            mha = block + '/MultiHeadDotProductAttention_1/'
            if 'attention_norm.weight' in layer_name:
                plugin_key_weight[block + '/LayerNorm_0/scale'] = v
            elif 'attention_norm.bias' in layer_name:
                plugin_key_weight[block + '/LayerNorm_0/bias'] = v
            elif 'ffn_norm.weight' in layer_name:
                plugin_key_weight[block + '/LayerNorm_2/scale'] = v
            elif 'ffn_norm.bias' in layer_name:
                plugin_key_weight[block + '/LayerNorm_2/bias'] = v
            elif 'fc1.weight' in layer_name:
                # Presumably an nn.Linear weight (out, in); the plugin kernel is its transpose.
                plugin_key_weight[block + '/MlpBlock_3/Dense_0/kernel'] = v.t().contiguous()
            elif 'fc1.bias' in layer_name:
                plugin_key_weight[block + '/MlpBlock_3/Dense_0/bias'] = v.contiguous()
            elif 'fc2.weight' in layer_name:
                plugin_key_weight[block + '/MlpBlock_3/Dense_1/kernel'] = v.t().contiguous()
            elif 'fc2.bias' in layer_name:
                plugin_key_weight[block + '/MlpBlock_3/Dense_1/bias'] = v.contiguous()
            elif 'attn' in layer_name:
                sub_attn_name = layer_name.split('.')[1]
                if 'weight' in layer_name:
                    if 'out' in layer_name:
                        kernel = v.view(hidden_size, num_heads, head_dim).permute(1, 2, 0)
                    else:
                        kernel = v.t().view(hidden_size, num_heads, head_dim)
                    plugin_key_weight[mha + sub_attn_name + '/kernel'] = kernel
                elif 'bias' in layer_name:
                    bias = v if 'out' in layer_name else v.view(num_heads, head_dim)
                    plugin_key_weight[mha + sub_attn_name + '/bias'] = bias
                else:
                    # Raise instead of exit(): callers can catch it, and the
                    # interpreter is not torn down from inside a helper.
                    raise ValueError('[ERROR] attn wrong ? The name is {}'.format(k))
            else:
                # Unknown per-layer parameter: reported and skipped.
                print('[ERROR] layer_name wrong ? The name is {}'.format(k))
    return plugin_key_weight
# Remap the checkpoint tensors to the plugin's weight names.
plugin_key_weight = load_torch_weights(ckpt_path)

# Load the exported ONNX model that will receive the fused plugin node.
model_file = 'custom_model_simp_0530.onnx'
model = onnx.load(model_file)
graph = model.graph
def get_node(onnx_model, node_name):
    """Find the first graph node whose name contains ``node_name``.

    Matching is by substring, so "TransformerPlugin" also matches
    "TransformerPlugin_0". The lookup is read-only: visited nodes are not
    modified.

    Args:
        onnx_model: an onnx ModelProto (anything exposing ``graph.node``).
        node_name: substring to look for in each node's ``name``.

    Returns:
        ``(node, index)`` of the first match, or ``(None, 0)`` when nothing
        matches -- test ``node is None``, since 0 is also a valid index.
    """
    return next(
        ((node, node_id)
         for node_id, node in enumerate(onnx_model.graph.node)
         if node_name in node.name),
        (None, 0),
    )
# Default attributes of the TransformerPlugin node; create_custom_node's
# ``attrs`` argument can override or extend them.
_DEFAULT_PLUGIN_ATTRS = {
    "img_size": 144,
    "patch_size": 8,
    "in_chans": 64,
    "embed_dim": 512,
    "num_heads": 16,
    "inter_size": 2048,
    "layer_num": 3,
    "with_cls_token": 0,
}


def _to_float32_array(value):
    """Return ``value`` (torch tensor or array-like) as a C-contiguous float32 numpy array."""
    if hasattr(value, "detach"):  # torch.Tensor: drop autograd and copy to host
        value = value.detach().cpu().numpy()
    return np.ascontiguousarray(value, dtype=np.float32)


def create_custom_node(node, out_idx, weights, attrs=None):
    """Build the nodes that replace ``node`` with a custom-domain plugin op.

    One ``Constant`` node is created per entry of ``weights`` (output and node
    name are the dict key). The plugin node takes ``node.input[0]`` followed by
    every weight name, in ``weights`` iteration order, and produces
    ``node.output[out_idx]``. Its op_type is the part of ``node.name`` before
    the first '_', and it lives in the "custom" domain.

    Side effect: ``node.output[out_idx]`` is renamed with a "_plugin_out"
    suffix so the original node no longer produces that tensor name.

    Args:
        node: the onnx NodeProto being replaced.
        out_idx: index of the output of ``node`` the plugin should produce.
        weights: mapping of tensor name -> torch tensor or numpy array; values
            are stored as FLOAT.
        attrs: optional dict of plugin attributes that overrides/extends
            ``_DEFAULT_PLUGIN_ATTRS``.

    Returns:
        list of NodeProto: the Constant nodes followed by the plugin node.
    """
    node_input_name = node.input[0]
    node_out_name = node.output[out_idx]
    node.output[out_idx] = node_out_name + "_plugin_out"

    inputs_name = [node_input_name]
    new_nodes = []
    for count, (k, v) in enumerate(weights.items(), start=1):
        arr = _to_float32_array(v)
        print(count, k, arr.shape, arr.size)
        # from_array serializes the whole buffer as raw_data in one step,
        # instead of feeding millions of torch scalars into a repeated field.
        constant_node = helper.make_node(
            op_type="Constant",
            inputs=[],
            outputs=[k],
            name=k,
            value=numpy_helper.from_array(arr, name=k),
        )
        inputs_name.append(k)
        new_nodes.append(constant_node)

    custom_node = helper.make_node(op_type=node.name.split('_')[0],
                                   inputs=inputs_name,
                                   outputs=[node_out_name],
                                   name=node.name,
                                   doc_string="custom plugin.",
                                   domain="custom")
    plugin_attrs = {**_DEFAULT_PLUGIN_ATTRS, **(attrs or {})}
    custom_node.attribute.extend(
        [helper.make_attribute(key, value) for key, value in plugin_attrs.items()])
    new_nodes.append(custom_node)
    return new_nodes
# Locate the placeholder plugin node in the exported graph.
node, node_id = get_node(model, "TransformerPlugin")
if node is None:
    raise RuntimeError("TransformerPlugin node not found in " + model_file)
print(node_id, node)
new_nodes = create_custom_node(node, 0, plugin_key_weight)

# Replace the placeholder in place. new_nodes holds the Constants first and
# the plugin last, so inserting them in reverse at node_id keeps that order.
model.graph.node.remove(node)
for new_node in reversed(new_nodes):
    model.graph.node.insert(node_id, new_node)

# The plugin node uses the "custom" domain; the checker requires an opset
# import for every domain used by a node.
if not any(opset.domain == "custom" for opset in model.opset_import):
    model.opset_import.append(helper.make_opsetid("custom", 1))

onnx.checker.check_model(model)
# infer_shapes returns a new ModelProto rather than modifying its argument.
model = shape_inference.infer_shapes(model)
onnx.save(model, "./model/fused_model_0530.onnx")