参考以下文章:
https://blog.csdn.net/xxradon/article/details/104715524
https://blog.csdn.net/weixin_43945848/article/details/122486725
https://blog.csdn.net/ZhangLH66/article/details/121247815
我的诉求是删了红框的部分,开一个 https://netron.app/,看onnx层要删除的层的名字
比如这个sigmoid层,点击右边的name就是名字
名字列表放到我这里定义的x 里面,后面会按从大到小降序删除,如果反着删的话,列表排序的位置会变,就会删错层,这里需要注意
删完了之后应该还要把输出层指出来,onnx的输出可以作为节点被删除,但是创建需要走一下他的创建输出流程,输出维度放在了output_shape_map中,自行修改输出维度后,校验一下模型,没问题就可以进行下一步了
代码放在下面了
import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse
model = "D:\project\Pythonproj\yolov5\yolo5\yolov5s.onnx"
# model = "D:\project\caffe\dockerfile\yolov5s-simple.onnx"
import onnx
onnx_model = onnx.load(model)
graph = onnx_model.graph
# print(graph)
node = graph.node
# node[213].output[0] = node[212].output[0]
# node[213].output[0] = node[213].input[0]
# for idx in graph.node:
# print(idx)
# graph.node[]
print(graph.output[0].type.tensor_type.shape)
# graph.output
# graph.output[1].type.tensor_type.elem_type = 1
# graph.output[2].name = "output2"
def createGraphMemberMap(graph_member_list):
member_map=dict()
for n in graph_member_list:
member_map[n.name]=n
return member_map
# x = {"Concat_304","Reshape_303","Reshape_267","Reshape_231","Sigmoid_218","Sigmoid_290","Sigmoid_254","Split_219","Split_255","Split_291","Mul_221","Mul_226","Mul_257","Mul_262","Mul_293","Mul_298","Mul_224","Mul_229","Mul_260","Mul_265","Mul_296","Mul_301","Add_222","Add_258","Add_294","Pow_228","Pow_264","Pow_300","Concat_230","Concat_266","Concat_302"}
# 待删除的层 开一个 https://netron.app/ 点击要删的层看name
x = {"Concat_250","Reshape_213","Reshape_249","Reshape_231","Sigmoid_199","Sigmoid_217","Sigmoid_235","Split_200","Split_218","Split_236","Mul_202","Mul_226","Mul_208","Mul_220","Mul_238","Mul_244","Mul_224","Mul_229","Mul_211","Mul_242","Mul_247","Mul_206","Add_222","Add_204","Add_240","Pow_227","Pow_209","Pow_245","Concat_230","Concat_212","Concat_248"}
de = []
num = 0
#
node_map = createGraphMemberMap(graph.node)
output_map = createGraphMemberMap(graph.output)
graph.output.remove(output_map["output0"])
new_output_node_names = ["output0","output1","output2"]
output_shape_map = [[1,3,80,80,85],[1,3,40,40,85],[1,3,20,20,85]]
for i in range(3):
new_nv = helper.make_tensor_value_info(new_output_node_names[i], TensorProto.FLOAT, output_shape_map[i])
graph.output.extend([new_nv])
output_map = createGraphMemberMap(graph.output)
for i in range(len(graph.node)):
if node[i].name in x:
de.append(i)
num = num+1
de.sort()
de.reverse()
for i in range(num):
graph.node.remove(graph.node[de[i]])
print("graph_output:", graph.output)
for i in range(len(graph.node)):
# if node[i].name == "Transpose_198":
# node[i].output[0]="output0"
# if node[i].name == "Transpose_234":
# node[i].output[0]="output1"
# if node[i].name == "Transpose_270":
# node[i].output[0]="output2"
if node[i].name == "Transpose_198":
node[i].output[0]="output0"
if node[i].name == "Transpose_216":
node[i].output[0]="output1"
if node[i].name == "Transpose_234":
node[i].output[0]="output2"
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model,"del_detect.onnx")