当PyTorch模型需要部署到服务时,为了提升访问速度,需要转换为TRT模型,再进行部署。在转换为TRT模型之前,需要将PyTorch参数模型(如pth.tar)转换为pt模型,使用jit形式。pt模型 = 参数模型(pth.tar) + 网络结构(如resnet50)。使用pt模型,可以简化使用方式,同时也方便转换为trt模型,进行轻量级部署。在转换函数中,包含验证逻辑,保证转换前后的模型效果一致,即输出不变。
以图像分类框架pytorch-image-models-my为例,将PyTorch的pth.tar模型转换为PT模型。
转换流程如下:
# 加载模型
model = timm.create_model(model_name=base_net, pretrained=False,
checkpoint_path=model_path, num_classes=num_classes)
if torch.cuda.is_available():
print('[Info] cuda on!!!')
model = model.cuda()
model.eval()
# 预测结果
print('[Info] 预测图像尺寸: {}'.format(img_rgb.shape))
img_tensor = self.preprocess_img(img_rgb, self.transform)
print('[Info] 模型输入: {}'.format(img_tensor.shape))
with torch.no_grad():
out = self.model(img_tensor)
将已加载的模型model,通过torch.jit.trace()
模拟输入dummy_input
,调用traced.save()
存储成pt模型,即:
注意输入尺寸dummy_shape
,用于生成模拟的input数据,需要与模型输入保持一致
注意是否支持GPU,即orch.cuda.is_available()
,判断环境是cuda还是cpu。
dummy_shape = (1, 3, 336, 336) # 不影响模型
print('[Info] dummy_shape: {}'.format(dummy_shape))
if torch.cuda.is_available():
model_type = "cuda"
else:
model_type = "cpu"
print('[Info] model_type: {}'.format(model_type))
dummy_input = torch.empty(dummy_shape,
dtype=torch.float32,
device=torch.device(model_type))
traced = torch.jit.trace(self.model, dummy_input)
pt_path = os.path.join(pt_folder_path, "{}_{}.pt".format(model_name, model_type))
traced.save(pt_path)
reload_script()
,即:with torch.no_grad():
standard_out = self.model(dummy_input)
print('[Info] standard_out: {}'.format(standard_out))
reload_script = torch.jit.load(pt_path)
with torch.no_grad():
script_output = reload_script(dummy_input)
print('[Info] script_output: {}'.format(script_output))
print('[Info] 验证 is equal: {}'.format(F.l1_loss(standard_out, script_output)))
print('[Info] 存储完成: {}'.format(pt_path))
全部转换和验证PT模型的逻辑,都位于save_pt()
函数中,调用即可生成,输出位于pt_models
文件夹中,即:
me.save_pt(os.path.join(DATA_DIR, "pt_models"))
输出的模型是:model_best_c2_20210915_cpu.pt
,GPU版本是:model_best_c2_20210915_cuda.pt
。
在pytorch-image-models-my工程中,pth.tar模型转换为PT模型的转换脚本,源码如下,参考model_2_pt_script.py:
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2021. All rights reserved.
Created by C. L. Wang on 15.9.21
"""
import argparse
import os
import sys
p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
sys.path.append(p)
from root_dir import DATA_DIR
from myscripts.img_predictor import ImgPredictor
def parse_args():
"""
处理脚本参数
"""
parser = argparse.ArgumentParser(description='PyTorch模型转换PT模型')
parser.add_argument('-m', dest='model_path', required=True, help='模型路径', type=str)
parser.add_argument('-n', dest='base_net', required=False, help='basenet', type=str, default="resnet50")
parser.add_argument('-c', dest='num_classes', required=False, help='类别个数', type=int, default=2)
parser.add_argument('-o', dest='out_dir', required=False, help='输出文件夹', type=str,
default=os.path.join(DATA_DIR, "pt_models"))
args = parser.parse_args()
arg_model_path = args.model_path
print("[Info] 模型路径: {}".format(arg_model_path))
arg_base_net = args.base_net
print("[Info] basenet: {}".format(arg_base_net))
arg_num_classes = args.num_classes
print("[Info] 类别数: {}".format(arg_num_classes))
arg_out_dir = args.out_dir
print("[Info] 输出文件夹: {}".format(arg_out_dir))
return arg_model_path, arg_base_net, arg_num_classes, arg_out_dir
def main():
"""
入口函数
"""
print('[Info] ' + "-" * 100)
print('[Info] 转换PT模型开始')
arg_model_path, arg_base_net, arg_num_classes, arg_out_dir = parse_args()
me = ImgPredictor(arg_model_path, arg_base_net, arg_num_classes)
pt_path = me.save_pt(arg_out_dir) # 存储PT模型
print('[Info] 存储完成: {}'.format(pt_path))
print('[Info] ' + "-" * 100)
if __name__ == "__main__":
main()