PyTorch模型转化ONNX在OpenCV上的使用

实现opencv调用ONNX模型,结果能够正常显示。

模型名称:ch_det_server_db_res18.pth,来源:WenmuZhou/PytorchOCR

Step 1:导出模型转化为ONNX模型

参考代码:

model_path = "ch_det_server_db_res18.pth"
onnx_save_path = "model_resnet18.onnx"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.randn(1, 3, 1280, 960, dtype=torch.float, device=device)



ckpt = torch.load(model_path, map_location='cpu')
cfg = ckpt['cfg']
model = build_model(cfg['model'])
state_dict = {}
for k, v in ckpt['state_dict'].items():
    state_dict[k.replace('module.', '')] = v

model.load_state_dict(state_dict)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()


print("Start convert model to onnx...")
torch.onnx.export(model,
                  data,
                  onnx_save_path,
                  opset_version=11,
                  # do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=["input"],  # 输入名
                  output_names=["output"]
                  # dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
                  #               "output": {0: "batch_size"}}

)

print("convert onnx is Done!")

注:在转化ONNX过程中可能出现不能转换,或者不能加载的情况,一般是由于不能加载某个模块造成的,建议百度,GITHUB,stackoverflow查询确定。

Step 2:  ONNX模型优化

python -m onnxsim model_resnet18.onnx model_resnet18-sim.onnx

Step 3: 检测模型是否可用,加载模型,数据标准化

3.1 检测模型是否可用(onnxruntime测试,opencv测试)

import onnxruntime
import torch

export_onnx_file = "model_resnet18_sim.onnx"
img_path = "12.jpg"
image = cv2.imread("12.jpg")  # 读取图片

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

src_img = cv2.imread(img_path, 1)
img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (960, 1280))
img1 = img / 255.0
img1 -= mean
img1 /= std
img1 = img1.transpose(2, 0, 1)
img1 = np.expand_dims(img1, axis=0)
img1 = np.float32(img1)

ort_session = onnxruntime.InferenceSession(export_onnx_file)
for input_meta in ort_session.get_inputs():
    print(input_meta)
for output_meta in ort_session.get_outputs():
    print(output_meta)

ort_inputs = {ort_session.get_inputs()[0].name: img1}
ort_outs = ort_session.run(["output"], ort_inputs)
pred_map = ort_outs[0]
pred_map = torch.from_numpy(pred_map)

3.2 加载模型(opencv)

net = cv2.dnn.readNetFromONNX("model_resnet18_sim.onnx")

3.3 数据预处理

(C++版本, Scalar  mean,std;  float scale;  bool swapRB)

if (rszWidth != 0 && rszHeight != 0)
{
    resize(frame, frame, Size(rszWidth, rszHeight));
}

//! [Create a 4D blob from a frame]
blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, crop);

// Check std values.
if (std.val[0] != 0.0 && std.val[1] != 0.0 && std.val[2] != 0.0)
{
    // Divide blob by std.
    divide(blob, std, blob);
}
//! [Create a 4D blob from a frame]

(Python版本)

cv_std = (1, 0.229, 0.224, 0.225)
cv_mean = (255*0.485, 255*0.456, 255*0.406)
blob = cv2.dnn.blobFromImage(img, scalefactor=1/255.0, mean=cv_mean)
cv2.divide(blob, cv_std, blob)

关键理解函数blobFromImage如下:
OpenCV中的DNN模块包含blobFromImage方法对输入神经网络的图像进行处理:

1.先相对于原图像中心resize,crop

2.再减均值

3.像素值缩放0-255 -> 0-1

4.图像数据通道转换,RGB->BGR,返回一个NCHW 数组

Step 4: 模型推测

net.setInput(blob)
out = net.forward()  

Step 5: 结果后处理(待完成)

你可能感兴趣的:(机器学习,opencv,pytorch,python)