TensorRT学习笔记--序列化(Serialize)和反序列化(Deserialize)

目录

前言

1--序列化和反序列化的概念

2--代码实现

2-1--序列化并保存模型

2-2--反序列化加载模型

2-3--完整代码


前言

        所有代码均基于 Tensor RT 8.2.5 版本,使用的是 Python 环境。

1--序列化和反序列化的概念

        将某个对象的信息转化成可以存储或者传输的信息,这个过程称为序列化;

        反序列化是序列化的相反过程,将信息还原为序列化前的状态;

        在 Pytorch 中,当序列化为 torch.save( ) 时,则反序列化可以是 torch.load( )

        在 Tensor RT 中,为了能够在 interface 的时候不需要重复编译 engine,倾向于将模型 序列化 成一个能够永久保存的 engine;当需要 interface 的时候,只需要通过简单的 反序列化 就能够快速加载 序列化保存好的模型 engine,节省部署开发的时间。

2--代码实现

2-1--序列化并保存模型

# 创建序列化engine
engine = builder.build_serialized_network(network, config)
   
# 保存序列化保存模型,便于后续直接调用
if True:
    saved_trt_path = "./serialize_fcn-resnet101.trt" # 序列化模型保存的地址
    with open(saved_trt_path, "wb") as f:
    f.write(engine) # 保存序列化模型

2-2--反序列化加载模型

# 反序列化加载模型
f = open(saved_trt_path, "rb") # 打开保存的序列化模型
runtime = trt.Runtime(TRT_LOGGER) 
engine = runtime.deserialize_cuda_engine(f.read()) # 反序列化加载模型
    
# 创建context用来执行推断
context = engine.create_execution_context()

2-3--完整代码

import pycuda
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt


if __name__ == "__main__":
    # 创建日志记录器
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    
    # 显式batch_size,batch_size有显式和隐式之分
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    
    # 创建builder,用于创建network
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH) # 创建network(初始为空)
    
    # 创建config
    config = builder.create_builder_config()
    profile = builder.create_optimization_profile() # 创建profile
    profile.set_shape("input", (1,3,256,256), (1,3,1026,1282), (1,3,1280,1536))  # 设置动态输入,"input"对应onnx模型的输入"name"
    #(1,3,256,256), (1,3,1026,1282), (1,3,1280,1536) 分别对应:最小尺寸、最佳尺寸、最大尺寸
    config.add_optimization_profile(profile)
    config.max_workspace_size = 1<<30 # 允许TensorRT使用1GB的GPU内存,<<表示左移,左移30位即扩大2^30倍,使用2^30 bytes即 1 GB
    
    # 创建parser用于解析模型
    parser = trt.OnnxParser(network, TRT_LOGGER)
    
    # 读取并解析模型
    onnx_model_file = "./fcn-resnet101.onnx" # Onnx模型的地址
    model = open(onnx_model_file, 'rb')
    if not parser.parse(model.read()): # 解析模型
        for error in range(parser.num_errors):
            print(parser.get_error(error)) # 打印错误(如果解析失败,根据打印的错误进行Debug)

    # 创建序列化engine
    engine = builder.build_serialized_network(network, config)
    
    # 保存序列化保存模型,便于后续直接调用
    if True:
        saved_trt_path = "./serialize_fcn-resnet101.trt" # 序列化模型保存的地址
        with open(saved_trt_path, "wb") as f:
            f.write(engine) # 保存序列化模型
     
    # 反序列化加载模型
    f = open(saved_trt_path, "rb") # 打开保存的序列化模型
    runtime = trt.Runtime(TRT_LOGGER) 
    engine = runtime.deserialize_cuda_engine(f.read()) # 反序列化加载模型
    
    # 创建context用来执行推断
    context = engine.create_execution_context()
    
    '''
    ...后续步骤
    '''

你可能感兴趣的:(TensorRT学习笔记,深度学习,人工智能)