关于C++ libtorch调用pytorch模型的总结

最近接到了一个需求,需要把一个用python基于pytorch实现的DQN强化学习模型移植到Arm平台。 经历了很多坑,最终都解决了,记录一下过程:

环境:
系统:Ubuntu20.04 LTS;
pytorch版本:1.9.0;
python版本:3.8;
libtorch版本:1.9;

准备步骤:
①如何安装python和pytorch请自行百度;
②进入pytorch官网,下载合适版本的libtorch https://pytorch.org/,因为我的需求是移植到Arm平台上,没有GPU以及CUDA,所以选择CPU版本;
关于C++ libtorch调用pytorch模型的总结_第1张图片
③解压缩后,得到libtorch的文件目录,这是官方在X86平台上已经编译好的一些库,参照
libtorch中文文档 https://pytorch.apachecn.org/docs/1.0/cpp_export.html,新建一个自己的目录,比如“example-app”。在“example-app”目录下新建如下文件:
(1)CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

(2)convert.py(将pytorch的模型转换为C++能用的libtorch模型)
此处要注意的是,你需要去看官方教程,不同的模型文件转换方式是不一样的,之前的那个中文文档里有写,此处记录我的模型转换过程。
我的实例:

import torch


model = torch.load('eval_net.pkl')

traced_script_module = torch.jit.script(model)

traced_script_module.save("eval_net.pt")

官方的实例:

import torch
import torchvision

# 获取模型实例
model = torchvision.models.resnet18()

# 生成一个样本供网络前向传播 forward()
example = torch.rand(1, 3, 224, 224)

# 使用 torch.jit.trace 生成 torch.jit.ScriptModule 来跟踪
traced_script_module = torch.jit.trace(model, example)

执行convert.py后生成 eval_net.pt 文件。
(3)example-app.cpp(用来测试转换为C++模型后的程序)

#include  
#include 
#include 

using namespace std;


int main() {
  torch::jit::script::Module module = torch::jit::load("./eval_net.pt");
  vector<torch::jit::IValue> inputs;
  //自定义一组输入数据并设置格式,否则会报错,对应pytorch中的FloatTensor
  inputs.push_back(torch::tensor({0, 1, 2, 3, 4, 6, 0, 0, 1, 0, 2, 1, 5, 6, 2, 3, 4, 6, 0, 5, 1, 3, 2, 1},torch::kFloat));
  //使用模型
  torch::Tensor res = module.forward(inputs).toTensor();
  for (size_t i = 0; i < res.itemsize(); i++)
  {
    if(res[i].equal(res.max())){
      cout << "result:"<< i << endl;
    }
  }
  cout << "加载成功\n";
}

此时example-app目录下文件应为:
在这里插入图片描述

此时在example-app下打开终端,执行语句:

cmake -DCMAKE_PREFIX_PATH=/home/software/libtorch-shared-with-deps-1.9.0+cpu/libtorch

运行后没有提示错误则为成功,目录下会生成Makefile文件。

执行语句:

make

生成example-app可执行程序。
执行语句,运行example-app得到如下提示,说明调用成功:
在这里插入图片描述
至此,X86版本的libtorch已经正常跑通。Arm版本的libtorch相对比较麻烦,需要官网下载pytorch源码,单独交叉编译libtorch,其它的调用方式跟X86是基本一样的,后续再补充。

你可能感兴趣的:(pytorch,c++,python,pytorch)