打印ONNX/TRT文件的所有节点

给一段python代码 可以查看.onnx文件的所有节点。

import onnx

def print_graph_nodes(model_path):
    # 加载 ONNX 模型
    model = onnx.load(model_path)

    # 遍历所有图节点并打印节点信息
    for node in model.graph.node:
        node_type = node.op_type
        node_name = node.name
        print(f'Node Type: {node_type}, Node Name: {node_name}')

if __name__ == '__main__':
    onnx_model_file = 'path/to/your/model.onnx'
    print_graph_nodes(onnx_model_file)

给一段python代码 可以查看.trt文件的所有节点

import tensorrt as trt

def print_network_nodes(trt_engine_path):
    # 加载TensorRT引擎
    with open(trt_engine_path, 'rb') as f, trt.Runtime(trt.Logger()) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())

    # 遍历所有网络层并打印节点信息
    for layer in engine:
        layer_type = layer.type
        layer_name = layer.name
        print(f'Layer Type: {layer_type}, Layer Name: {layer_name}')

if __name__ == '__main__':
    trt_engine_file = 'path/to/your/model.trt'
    print_network_nodes(trt_engine_file)

你可能感兴趣的:(c++,dnn,cnn)