

import os
import sys
import torch
import numpy as np

from feat.model import ResNet  # 导入自己的模型类

def load_checkpoint(checkpoint_file, model):
    """Load model weights from a checkpoint file and return its epoch.

    Args:
        checkpoint_file: Path to a file readable by ``torch.load`` that holds
            a dict with an ``"epoch"`` entry and the model weights.
        model: Module whose ``load_state_dict`` receives the stored weights.

    Returns:
        The value stored under ``checkpoint["epoch"]``.

    Raises:
        AssertionError: If ``checkpoint_file`` does not exist.
        KeyError: If the checkpoint has no recognised weights entry or no
            ``"epoch"`` entry.
    """
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # NOTE(review): the key the training code stores the weights under is not
    # visible in this file, so the common names are tried in order -- confirm.
    for key in ("model_state", "state_dict", "model_state_dict"):
        if key in checkpoint:
            model.load_state_dict(checkpoint[key])
            break
    else:
        # Fail loudly: continuing would leave the model with its initial
        # (untrained) parameters.
        raise KeyError(
            "No model weights found in checkpoint '{}'; available keys: {}".format(
                checkpoint_file, sorted(checkpoint)
            )
        )
    return checkpoint["epoch"]

if __name__ == '__main__':
    # Select the GPU this process is allowed to see.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # The checkpoint path is taken from the command line.
    if len(sys.argv) < 2:
        sys.exit('usage: {} <checkpoint_file>'.format(sys.argv[0]))
    model_filename = sys.argv[1]

    # init model
    model = ResNet()
    load_checkpoint(model_filename, model)
    model = model.cuda()
    # Inference mode, so layers such as dropout/batch-norm behave
    # deterministically while the graph is traced.
    model.eval()

    onnx_name = 'resnet.onnx'  # output ONNX file
    # Dummy input with the model's input size, placed on the same device.
    example = torch.randn((1, 3, 224, 224)).cuda()
    input_names = ["input"]
    output_names = ["outputs"]
    # Convert the model and save it. dynamic_axes=None keeps every axis
    # (including the batch axis) fixed to the shape of `example`.
    torch.onnx.export(
        model,
        example,
        onnx_name,
        opset_version=12,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=None,
    )


import os
import sys
import torch
import numpy as np
import onnxruntime
import time

if __name__ == '__main__':
    # Select the GPU this process is allowed to see.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # The checkpoint path is taken from the command line.
    if len(sys.argv) < 2:
        sys.exit('usage: {} <checkpoint_file>'.format(sys.argv[0]))
    model_filename = sys.argv[1]
    onnx_name = 'resnet.onnx'  # ONNX file to compare against

    # init model
    model = ResNet()
    load_checkpoint(model_filename, model)
    model = model.cuda()
    # Inference mode so the PyTorch output is comparable with the ONNX one.
    model.eval()
    session = onnxruntime.InferenceSession(onnx_name, providers=['CUDAExecutionProvider'])
    img = np.random.randn(1, 3, 224, 224).astype(np.float32)  # random input
    t1 = time.time()
    # Passing None as the output list asks the session for all outputs.
    onnx_preds = session.run(None, {"input": img})
    print("onnx preds result: ", onnx_preds)
    t2 = time.time()
    pth_preds = model(torch.from_numpy(img).cuda())
    print("pth preds result: ", pth_preds)
    t3 = time.time()
    print("onnx cost time: ", t2 - t1, " pth cost time: ", t3 - t2)


onnx preds res:  [array([[-0.13128008,  0.04037811,  0.0529038 ,  0.101323  , -0.03352938,
         0.03099938,  0.06380229, -0.03544223, -0.03368076,  0.06361518,                                                                                                                     
        -0.00668521, -0.01996843, -0.0132075 , -0.03448019,  0.17793381,                                                                                                                     
         0.08131739,  0.10232763, -0.09122676,  0.01173838,  0.03181053,                                                                                                                     
        -0.05899123,  0.01569226, -0.04734752, -0.12551421,  0.00686131,                                                                                                                     
        -0.00749457, -0.03729884,  0.05349742,  0.0304895 ,  0.02956274,                                                                                                                     
         0.00393172,  0.00196273,  0.01296113, -0.03985897, -0.06289426,                                                                                                                     
        -0.0825834 , -0.28903952,  0.02842386, -0.1718263 , -0.05555207,                                                                                                                     
        -0.03707219,  0.10904352,  0.06582819,  0.04960179,  0.01508415,                                                                                                                     
         0.05469472,  0.28663486,  0.1183752 , -0.06070469, -0.05200525,                                                                                                                     
        -0.03477468, -0.06193898, -0.04432139,  0.0843045 , -0.12080704,                                                                                                                     
         0.00163073, -0.08544722,  0.11994477,  0.02619292,  0.05066012,                                                                                                                     
        -0.00332941, -0.1488586 ,  0.07936171,  0.06203181, -0.0645356 ,                                                                                                                     
        -0.07661135, -0.05883927, -0.00459472, -0.06721105, -0.02880175,                                                                                                                     
        -0.00337263, -0.00927516,  0.03289868,  0.10054352, -0.09545278,                                                                                                                     
        -0.0216963 ,  0.11413048, -0.04580398,  0.02614305, -0.08269466,                                                                                                                     
         0.01835637,  0.17654261,  0.0573773 , -0.06440263,  0.01176349,                                                                                                                     
         0.00998674,  0.02840159,  0.14086637, -0.02473863,  0.05228964,                                                                                                                     
        -0.03329878, -0.02751228, -0.04788758,  0.1546051 ,  0.05838795,                                                                                                                     
        -0.02351469, -0.01315547, -0.13732813, -0.08146078,  0.01943143,                                                                                                                     
        -0.08991284,  0.14222968, -0.14729632,  0.24547395, -0.05293949,                                                                                                                     
         0.04446511,  0.05436133, -0.09403729, -0.0900671 ,  0.04516568,                                                                                                                     
         0.10035874, -0.03281724,  0.19480802, -0.11344203, -0.02487336,                                                                                                                     
        -0.08126407, -0.00491623,  0.04313428, -0.10474856, -0.11427435,                                                                                                                     
        -0.01765379, -0.04613522,  0.08338863,  0.00564523,  0.14067101,                                                                                                                     
         0.05428562,  0.12530491, -0.2503076 ]], dtype=float32)]                                                                                                                             
pth preds res:  tensor([[-0.1313,  0.0404,  0.0529,  0.1013, -0.0335,  0.0310,  0.0638, -0.0354,                                                                                             
         -0.0337,  0.0636, -0.0067, -0.0200, -0.0132, -0.0345,  0.1779,  0.0813,                                                                                                             
          0.1023, -0.0912,  0.0117,  0.0318, -0.0590,  0.0157, -0.0473, -0.1255,                                                                                                             
          0.0069, -0.0075, -0.0373,  0.0535,  0.0305,  0.0296,  0.0039,  0.0020,                                                                                                             
          0.0130, -0.0399, -0.0629, -0.0826, -0.2890,  0.0284, -0.1718, -0.0556,                                                                                                             
         -0.0371,  0.1090,  0.0658,  0.0496,  0.0151,  0.0547,  0.2866,  0.1184,                                                                                                             
         -0.0607, -0.0520, -0.0348, -0.0619, -0.0443,  0.0843, -0.1208,  0.0016,                                                                                                             
         -0.0854,  0.1199,  0.0262,  0.0507, -0.0033, -0.1489,  0.0794,  0.0620,                                                                                                             
         -0.0645, -0.0766, -0.0588, -0.0046, -0.0672, -0.0288, -0.0034, -0.0093,                                                                                                             
          0.0329,  0.1005, -0.0955, -0.0217,  0.1141, -0.0458,  0.0261, -0.0827,                                                                                                             
          0.0184,  0.1765,  0.0574, -0.0644,  0.0118,  0.0100,  0.0284,  0.1409,                                                                                                             
         -0.0247,  0.0523, -0.0333, -0.0275, -0.0479,  0.1546,  0.0584, -0.0235,                                                                                                             
         -0.0132, -0.1373, -0.0815,  0.0194, -0.0899,  0.1422, -0.1473,  0.2455,                                                                                                             
         -0.0529,  0.0445,  0.0544, -0.0940, -0.0901,  0.0452,  0.1004, -0.0328,                                                                                                             
          0.1948, -0.1134, -0.0249, -0.0813, -0.0049,  0.0431, -0.1047, -0.1143,                                                                                                             
         -0.0177, -0.0461,  0.0834,  0.0056,  0.1407,  0.0543,  0.1253, -0.2503]],                                                                                                           
       device='cuda:0', grad_fn=<AddmmBackward>)
onnx cost time:  0.0062367916107177734  pth cost time:  0.030622243881225586


import os
import tensorrt as trt

# Restrict the visible GPUs before any TensorRT object is created: the
# variable is only honoured if it is set before CUDA is initialised
# (presumably triggered by trt.Runtime below -- confirm).
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# Logger shared by every TensorRT object in this script; WARNING and above.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt_runtime = trt.Runtime(TRT_LOGGER)

# Directory that contains this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Bit mask passed to builder.create_network() to request an explicit-batch
# network definition.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def get_engine(input_shape, onnx_file_path="", engine_file_path=""):
    """Load a serialized TensorRT engine if one exists, otherwise build and save it.

    Args:
        input_shape: Shape assigned to the network's first input before the
            engine is built.
        onnx_file_path: Path of the ONNX model to parse when building.
        engine_file_path: Path the serialized engine is read from, or written
            to after a successful build.

    Returns:
        The deserialized or newly built engine, or None when the ONNX file
        cannot be parsed or the build fails.

    Raises:
        FileNotFoundError: If an engine has to be built and ``onnx_file_path``
            does not exist.
    """
    def build_engine():
        """Parse the ONNX file, build a TensorRT engine and write it to disk."""
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser, \
                builder.create_builder_config() as config:
            config.max_workspace_size = 1 << 33  # 8 GiB
            # FP32 is used by default; set trt.BuilderFlag.FP16 on `config`
            # to build with FP16 precision instead.
            builder.max_batch_size = 1
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run torch2onnx first to generate it.'.format(onnx_file_path))
                raise FileNotFoundError(onnx_file_path)
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    print('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            # Pin the first network input to the requested shape.
            network.get_input(0).shape = input_shape
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_engine(network, config)
            if engine is None:
                # Nothing to serialize; report instead of failing on None.
                print('ERROR: Failed to build the engine.')
                return None
            print("Completed creating Engine")
            # Cache the engine so later runs can skip the build.
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    return build_engine()

if __name__ == '__main__':
    # Build the engine for a fixed [1, 3, 224, 224] input, or load the
    # serialized engine if it is already on disk.
    onnx_path, engine_path = 'resnet.onnx', 'resnet.engine'
    get_engine([1, 3, 224, 224], onnx_path, engine_path)


import os
import sys
import cv2
import copy
import torch
import numpy as np
import time
import onnxruntime
import pycuda.driver as cuda
import tensorrt as trt

# Logger shared by the TensorRT runtime created in this script.
TRT_LOGGER = trt.Logger()
import trt_common
# Bit mask that requests an explicit-batch network definition.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
# No default-encoding override is needed: on Python 3
# sys.getdefaultencoding() is always 'utf-8'.

# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    """Pair of a host buffer and the device buffer that mirrors it."""

    def __init__(self, host_mem, device_mem):
        # Host-side buffer (allocate_buffers passes page-locked memory here).
        self.host = host_mem
        # Device-side allocation paired with the host buffer.
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

def get_engine(engine_file_path):
    """Deserialize the TensorRT engine stored at ``engine_file_path``.

    Args:
        engine_file_path: Path of a serialized engine file.

    Returns:
        The result of ``runtime.deserialize_cuda_engine`` on the file's bytes.
    """
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
    """Allocate a host/device buffer pair for every binding of ``engine``.

    Args:
        engine: TensorRT engine whose bindings are iterated.

    Returns:
        A tuple ``(inputs, outputs, bindings, stream)``: the HostDeviceMem
        pairs of the input bindings, those of the output bindings, the device
        buffer addresses as ints in binding order, and a new CUDA stream.
    """
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        buffers = HostDeviceMem(host_mem, device_mem)
        if engine.binding_is_input(binding):
            inputs.append(buffers)
        else:
            outputs.append(buffers)
    return inputs, outputs, bindings, stream

# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):
    """Run one inference on ``context`` and return the host output buffers.

    Args:
        context: Execution context providing ``execute_async_v2``.
        bindings: Device buffer addresses passed to the context.
        inputs: HostDeviceMem pairs whose host data is copied to the device.
        outputs: HostDeviceMem pairs whose device data is copied back.
        stream: CUDA stream all copies and the execution are queued on.

    Returns:
        A list with the ``host`` buffer of every entry in ``outputs``.
    """
    # Transfer input data to the GPU.
    for inp in inputs:
        cuda.memcpy_htod_async(inp.device, inp.host, stream)
    # Run inference.
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    for out in outputs:
        cuda.memcpy_dtoh_async(out.host, out.device, stream)
    # Synchronize the stream so the asynchronous copies have finished
    # before the host buffers are read.
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]

if __name__ == '__main__':
    onnx_name = 'resnet.onnx'
    trt_name = 'resnet.engine'

    session = onnxruntime.InferenceSession(onnx_name, providers=['CUDAExecutionProvider'])

    # Imported for its side effect of setting up the CUDA context pycuda uses.
    import pycuda.autoprimaryctx
    engine = get_engine(trt_name)
    context = engine.create_execution_context()
    inputs, outputs, bindings, stream = allocate_buffers(engine)

    img = cv2.imread('test.jpg')
    if img is None:
        # cv2.imread returns None instead of raising when the file cannot be read.
        raise FileNotFoundError("Could not read image 'test.jpg'")
    img = cv2.resize(img, (224, 224))
    # HWC -> CHW, float32, then add a leading batch axis.
    img = img.transpose([2, 0, 1]).astype(np.float32)
    img = np.expand_dims(img, axis=0)
    t1 = time.time()
    # Passing None as the output list asks the session for all outputs.
    onnx_preds = session.run(None, {"input": img})
    #print("onnx_preds: ", onnx_preds)
    t2 = time.time()

    # The transposed array is not contiguous; the host-to-device copy needs
    # a contiguous buffer.
    inputs[0].host = np.ascontiguousarray(img)
    trt_outputs = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    # Copy the result out of the reusable output buffer.
    data = copy.deepcopy(trt_outputs[0])
    #print("preds: ", data)
    t3 = time.time()
    print("onnx: ", t2-t1, " trt: ", t3-t2)



ERROR: Failed to parse the ONNX file.

In node 84 (importConv): UNSUPPORTED_NODE: Assertion failed: inputs.at(2).is_weights() && "The bias tensor is required to be an initializer for the Conv operator."


pip install onnx-simplifier


import onnx
from onnxsim import simplify

# Load the exported model, simplify it, and save the result under a new name.
onnx_model = onnx.load('resnet.onnx')
model_simp, check = simplify(onnx_model)
# `check` reports whether the simplified model passed onnxsim's validation.
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, 'resnet_sim.onnx')


ValueError: ndarray is not contiguous


数组不连续,使用np.ascontiguousarray(img) 处理数组

inputs[0].host = np.ascontiguousarray(img)


Error Code 1: Myelin (Compiled against cuBLASLt <version> but running against cuBLASLt <different version>)


tensorrt 和 torch 同时使用时会调用不同版本的 cuBLASLt,因此不能同时使用。tensorrt 和 onnxruntime 同时使用也会发生同样的错误。
