【AI】PyTorch实战(四):使用C++调用PyTorch模型

1、简述

PyTorch是由python语言实现,想要使用C++调用PyTorch模型,需要先将PyTorch模型的结构和参数保存成一种和具体编程语言无关的格式,然后使用C++来解析这种格式即可。TorchScript程序就是使用python语言,将PyTorch模型的结构和参数保存下来,这一操作称为序列化。

PyTorch提供C++接口库——libTorch,该库用于加载已经序列化的PyTorch模型,这一操作称为反序列化。

2、使用步骤

2.1 生成torch.jit.ScriptModule对象

在序列化之前,首先需要在python环境中产生一个ScriptModule对象,这个对象有执行序列化的接口。
有两种方法产生ScriptModule对象,一种称为==“跟踪”、另一种称为“注解”
1)所谓
“跟踪”就是加载网络模型后,创建一个随机输入,在“特定方法”中,执行一次==预测(向前传播)操作,这个“特定方法”(torch.jit.trace)可以跟踪预测的过程,获取预测时网络模型的结构和参数,将它们封装到ScriptModule对象中。

示例代码如下:

import torch
import torc

你可能感兴趣的:(AI,1024程序员节,pytorch)