pytorch的部署——把pytorch模型集成到so库

前沿

需要pytorch1.0版本及以上,linux环境下进行,借助cmake编译

1.编写保存pytorch模型代码

//main.py
import torch
class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if bool(input.sum() > 0):
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_script_module = MyModule(2, 3)
my_script_module.save("model.pt")

执行python main.py得到模型

2. 编写被打包的so代码

//test.cpp
#include  // One-stop header.

#include 
#include 
#include "test.h"
//arg为传入的模型name
void test(char* arg) {
  // Deserialize the ScriptModule from a file using torch::jit::load().
  std::shared_ptr module = torch::jit::load(arg);
  std::cout << "hello pytorch lib,ok\n";
}

3. 编写cmake配置文件

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops2)

find_package(Torch REQUIRED)
#生成可执行文件
#add_executable(example-app example-app.cpp)
#生成so库
add_library( # Sets the name of the library.
             pytorchtest
             # Sets the library as a shared library.
             SHARED
             # Provides a relative path to your source file(s).
             test.cpp)

target_link_libraries(pytorchtest "${TORCH_LIBRARIES}")
set_property(TARGET pytorchtest PROPERTY CXX_STANDARD 11)

4. 编译

在当前目录新建build文件夹,然后编译,如下:

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/XXX/libtorch ..
make

这里的/XXX/libtorch路径是在pytorch官网下载的c++库解压后的路径,因为前面建立的模型没有用到GPU,这里下载的是linux中的cpu版本:
pytorch的部署——把pytorch模型集成到so库_第1张图片

5. 调用共享库so

前面步骤将在build目录下生成库文件libpytorchtest.so,调用so库需要配置找到该库文件的路径,这里把so放到~/lib中,配置环境变量:
export LD_LIBRARY_PATH=/home/haward/lib:$LD_LIBRARY_PATH
接着编写main.c去调用so

//main.c
#include "test.h"
int main()
{
	test("model.pt");
    return 0;		
}

编译main.c的命令:
g++ main.cpp -L ~/lib -lpytorchtest -o main
其中,-L表示链接的so存放的路径;-l接省略lib…so的命名(-Iinclude表示找到头文件include目录,这里在.目录下,可省略)

得到main可执行文件:执行./main

输出:"hello pytorch lib,ok\n"表示加载模型成功。
参考:
https://pytorch.org/tutorials/advanced/cpp_export.html
代码下载

你可能感兴趣的:(pytorch的部署——把pytorch模型集成到so库)