ONNX 删除指定的结点

import onnx
import sys
from onnx import helper

def sanitize_name(name):
    """Return *name* with every "onnx::" marker removed and "." replaced by "_".

    The same rewrite is applied to every tensor name in the graph so that all
    references to one tensor remain identical after renaming.

    >>> sanitize_name("onnx::Conv_5")
    'Conv_5'
    >>> sanitize_name("input.1")
    'input_1'
    """
    # Strip the exporter prefix first, then replace dots (same order as the
    # original per-input loop).
    return name.replace("onnx::", "").replace(".", "_")


def rename_tensors(graph):
    """Rewrite every tensor name in *graph* in place using ``sanitize_name``.

    Graph inputs, graph outputs, value_info entries, initializers, and every
    node's input and output lists are all renamed. Renaming only one side of
    an edge breaks the graph: for example, renaming a node input but not the
    producing node's output or the initializer it refers to. Empty names stay
    empty.

    Graphs nested inside node attributes are not visited.

    NOTE(review): two distinct names can collide after rewriting
    (e.g. "a.b" and "a_b"); this is not detected.
    """
    for value in (*graph.input, *graph.output, *graph.value_info):
        value.name = sanitize_name(value.name)
    for tensor in graph.initializer:
        tensor.name = sanitize_name(tensor.name)
    for node in graph.node:
        # Assign by index: repeated string fields are updated element-wise,
        # and the container's length never changes here.
        for idx, name in enumerate(node.input):
            node.input[idx] = sanitize_name(name)
        for idx, name in enumerate(node.output):
            node.output[idx] = sanitize_name(name)


def main(argv=None):
    """Sanitize the tensor names of an ONNX model and save the result.

    Usage: ``program input-model output-model``.

    Args:
        argv: command-line arguments *without* the program name; defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: with status 1 (after printing usage to stderr) when fewer
            than two arguments are given.

    NOTE(review): the file heading speaks of deleting nodes, and the original
    script declared ``remove_list``/``conn_node`` without ever using them. No
    node-deletion logic is present in this source, so none is performed here.
    """
    args = sys.argv[1:] if argv is None else argv
    if len(args) < 2:
        # Misuse goes to stderr with a non-zero status so callers can detect it.
        print("usage:program input-model output-model", file=sys.stderr)
        sys.exit(1)

    input_model, output_model = args[0], args[1]
    onnx_model = onnx.load(input_model)
    rename_tensors(onnx_model.graph)
    # The output path was parsed but never written in the original.
    onnx.save(onnx_model, output_model)


if __name__ == "__main__":
    main()

你可能感兴趣的：ONNX、Python、服务器、Linux