yolov3修改替换onnx模型节点(Resize-->DConv)

1. 定位

找到将要替换的节点位置, model.graph.node中节点参数是受保护变量不能直接循环,需要通过索引访问, model.graph.node[i]

yolov3修改替换onnx模型节点(Resize-->DConv)_第1张图片

model = onnx.load('../onnx_model/yolov3.onnx')
node = model.graph.node

#查找模型中Resize算子节点位置
for i in range(len(node)):
    if node[i].op_type == 'Resize':
        node_rise = node[i]
        print(i)

2. 创建新节点

# 新建新节点(反卷积)
deconv1_node = onnx.helper.make_node(
    name="ConvTranspose_143",           #节点名称
    op_type="ConvTranspose",            #算子类型
    inputs=["639", 'deconv1.weight'],   #输入参数,可以为多个,如['X','W','b']
    outputs=["644"],                    #输出节点名字(下一层输入节点名称)
    output_padding=[1,1],           #输出特征图填充
    group=1,
    kernel_shape=[3,3],
    pads=[1,1,1,1],
    strides=[2,2],
)

3. 删除旧节点,插入新节点

# 删除老节点
old_node = model.graph.node[index]
model.graph.node.remove(old_node)
# 插入新节点
model.graph.node.insert(index, new_node[i])

4. 初始化输入参数(W,b,等)

# 权重shape
deconv1_weight_shape = [256, 256, 3, 3]
# 生成weight数据
weight1 = np.random.ranf(256 * 256 * 3 * 3).astype(np.float32)
#转换成onnx类型weight数据
deconv1_weight = helper.make_tensor('deconv1.weight', TensorProto.FLOAT, deconv1_weight_shape, weight1)

5.  初始化权重W

#将生成的onnx权重数据添加到 model.graph.initializer中
model.graph.initializer.append(deconv1_weight)

6. 保存到新模型

onnx.save(model, "my.onnx")

7. 注意

使用helper.make_tensor()生成onnx权重时,'deconv1.weight' 是生成的权重名称,也是model.graph.initializer中显示的名称,要和节点input中的输入名称一样,这样将权重添加到model.graph.initializer时才能被正确识别

yolov3修改替换onnx模型节点(Resize-->DConv)_第2张图片  yolov3修改替换onnx模型节点(Resize-->DConv)_第3张图片

helper.make_tensor('deconv1.weight', TensorProto.FLOAT, weight_shape, weight1)

8. 完整代码

import numpy as np
import onnx
from onnx import helper, TensorProto

# kernel 2x2的反卷积替换2倍的resize上采样(最近邻插值)

def node_replace(model, new_node:list, old_index:list, weight:list, output):
    # 目前仅支持替换为不含bias的反卷积算子
    """

    :param model:  onnx模型
    :param new_node: 需要替换的新节点信息列表
    :param old_index: 需要替换的旧节点索引列表
    :param weight: 新节点的权重列表
    :param output: 保存的新模型名字 ,如:my.onnx
    :return: None
    """
    for i, index in enumerate(old_index):
        # 删除老节点
        old_node = model.graph.node[index]
        model.graph.node.remove(old_node)
        # 添加新节点
        model.graph.node.insert(index, new_node[i])
        # 初始化权重
        model.graph.initializer.append(weight[i])
    onnx.save(model, output)

model = onnx.load('../onnx_model/yolov3.onnx')
# model = onnx.load('my.onnx')

node = model.graph.node

# 1.2搜索目标节点
for i in range(len(node)):
    if node[i].op_type == 'Resize':
        node_rise = node[i]
        print(i)
# print(node[159])


# 节点信息
deconv1_weight_shape = [256, 256, 2, 2]
deconv2_weight_shape = [128, 128, 2, 2]

# input_shape = [1, 256, 13, 13]
# deconv1_output_shape = [1, 256, 26, 26]
# deconv2_output_shape = [1, 128, 52, 52]

# 生成weight
weight1 = np.zeros([256,256,2,2]).astype(np.float32)
weight2 = np.zeros([128,128,2,2]).astype(np.float32)
# 初始化权值
for i in range(256):
    weight1[i, i, :, :] = 1
for i in range(128):
    weight2[i, i, :, :] = 1
print(weight1.shape)
print(weight2.shape)

deconv1_weight = helper.make_tensor('deconv1.weight', TensorProto.FLOAT, deconv1_weight_shape, weight1)
deconv2_weight = helper.make_tensor('deconv2.weight', TensorProto.FLOAT, deconv2_weight_shape, weight2)

# output = s*(in-1) + output_padding + ((k-1)*d + 1) - 2p
# # 新建新节点
deconv1_node = onnx.helper.make_node(
    name="ConvTranspose_143",
    op_type="ConvTranspose",
    inputs=["639", 'deconv1.weight'],
    outputs=["644"],
    output_padding=[0,0],
    group=1,
    kernel_shape=[2,2],
    pads=[0,0,0,0],
    strides=[2,2],
    dilations=[1,1]
)

# # 新建新节点
deconv2_node = onnx.helper.make_node(
    name="ConvTranspose_161",
    op_type="ConvTranspose",
    inputs=["667", 'deconv2.weight'],
    outputs=["672"],
    output_padding=[0,0],
    group=1,
    kernel_shape=[2,2],
    pads=[0,0,0,0],
    strides=[2,2],
    dilations=[1,1]
)

node_replace(model, [deconv1_node,deconv2_node],[142,159],[deconv1_weight,deconv2_weight], 'out.onnx')

# ##模型检查
try:
    onnx.checker.check_model(model)
except onnx.checker.ValidationError as e:
    print('The model is invalid: %s' % e)
else:
    print('The model is valid!')




你可能感兴趣的:(模型量化,python,开发语言)