Loading a TorchScript Model in C++
Recommended reading from the official documentation:
How to save a model
How to load a model in C++
The PyTorch C++ API
The TorchScript documentation
The notes below add some commentary on top of the official documentation.
A PyTorch model gets from Python to C++ via TorchScript, and there are two ways to convert a model to TorchScript: tracing and annotation (scripting).
Tracing runs the model once on an example input and captures its structure and parameters by recording how that input flows through the model.
Annotation adds explicit annotations to the model so that the TorchScript compiler parses the model code directly; this is needed when the model contains data-dependent control flow (such as the if branch in MyModule below), which tracing cannot capture.
Official example code:
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()  # load a model from torchvision

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)  # create an example input

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)  # call torch.jit.trace() to trace the model

# Extra step: save the traced module directly.
traced_script_module.save('xxxx.pt')
class MyModule(torch.nn.Module):  # definition of an example model
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)  # instantiate the model
sm = torch.jit.script(my_module)  # convert it with torch.jit.script
sm.save('xxxxx.pt')  # save it as a .pt file
The simplest C++ program for loading a TorchScript module:
// torch/script.h is provided by the libtorch library
#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
}
The corresponding CMakeLists.txt:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
Next, download libtorch. Prebuilt packages for every version are available for download from the official website.
Alternatively, you can choose to build it from source yourself.
The libtorch folder should contain:
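(A typical release package looks roughly like the following; the exact contents can vary between versions.)
libtorch/
  bin/
  include/
  lib/
  share/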
At this point the C++ project directory should look like:
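(Assuming the minimal layout from the official tutorial, i.e. just the source file and the CMake script shown above:)
example-app/
  CMakeLists.txt
  example-app.cpp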
Configure and build the application:
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
If everything goes well, the output looks like this:
root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
Test it:
./example-app xxxx.pt
The expected output is ok, which means the model has been loaded successfully.
Now let's try running inference.
At the end of the main function, add the following code:
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
Save the file, then rebuild and run it the same way; the output is:
-0.2698 -0.0381 0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
For comparison, the earlier output in Python was:
tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
If you want to run the test on the GPU, add
module.to(at::kCUDA);
right after
module = torch::jit::load(argv[1]);
and make sure the model's inputs also live in CUDA memory, e.g. by moving each input tensor with
tensor.to(at::kCUDA)
(note that Tensor::to returns a new tensor rather than modifying the original in place).
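Putting the pieces together, a minimal sketch of a GPU version of main might look like the following. This assumes a CUDA-enabled libtorch build and a CUDA-capable machine; it is a sketch assembled from the steps above, not a verbatim copy of the official example.
#include <torch/script.h> // One-stop header.

#include <iostream>
#include <vector>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Load the ScriptModule and move its parameters and buffers to the GPU.
  torch::jit::script::Module module = torch::jit::load(argv[1]);
  module.to(at::kCUDA);

  // The inputs must live in CUDA memory as well; Tensor::to returns a new tensor.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));

  // Run inference on the GPU and copy the result back to the CPU for printing.
  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5).to(at::kCPU) << '\n';
  return 0;
}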