Download the LibTorch package built for CUDA 11.3 and unzip it; that is the whole installation:
# If you need cpu support, please replace "cu113" with "cpu" in the URL below.
wget https://download.pytorch.org/libtorch/nightly/cu113/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
If you need the CPU-only package, substitute the URL as shown in the comment above.
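To verify the unpacked LibTorch, a minimal smoke test such as the sketch below can be built against it (using the CMake setup described later in this post; the file name smoke_test.cc is my own choice):
// smoke_test.cc: minimal LibTorch sanity check.
#include <iostream>
#include "torch/torch.h"

int main()
{
    // If this compiles, links, and prints a 2x3 random tensor,
    // the installation works.
    std::cout << torch::rand({2, 3}) << std::endl;
    std::cout << "CUDA available: " << torch::cuda::is_available() << std::endl;
    return 0;
}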
A Python helper for saving a tensor to disk; _use_new_zipfile_serialization=True writes the zip-based format that torch::pickle_load on the C++ side can parse:
import io
import torch
from torch import Tensor

def save_tensor(data_tensor: Tensor, name: str):
    print("[python] %s: " % name, data_tensor)
    # Serialize to an in-memory buffer first
    f = io.BytesIO()
    torch.save(data_tensor, f, _use_new_zipfile_serialization=True)
    with open('/home/chenxin/peanut/DenseTNT/test_cpp/dense_tnt_infer/data/%s.pt' % name, "wb") as out_f:
        # Copy the BytesIO stream to the output file
        out_f.write(f.getbuffer())
Several tensors can also be saved into a single file, which makes them easier to manage. Example (C++ side):
// Save several tensors into one archive; torch::save stores them
// in order as parameters named "0", "1", ...
std::string save_path = "peanut/DenseTNT/test_cpp/dense_tnt_infer/data_atlas/";
torch::save({vector_data.unsqueeze(0),
             vector_mask.unsqueeze(0),
             traj_polyline_mask.unsqueeze(0).toType(torch::kBool),
             map_polyline_mask.unsqueeze(0).toType(torch::kBool)},
            save_path + "data.pt");
On the Python side, the tensors saved from C++ can be read back, in save order, through _parameters:
import torch

device = torch.device("cuda:0")
save_path = "peanut/DenseTNT/test_cpp/dense_tnt_infer/data_atlas/"
input_data = torch.jit.load(save_path + "data.pt")
vector_data = input_data._parameters['0'].to(device)
vector_mask = input_data._parameters['1'].bool().to(device)
traj_polyline_mask = input_data._parameters['2'].bool().to(device)
map_polyline_mask = input_data._parameters['3'].bool().to(device)
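The same archive can also be read back directly in C++ via torch::load; a minimal sketch, reusing save_path from the torch::save example above:
// Read the multi-tensor archive written by torch::save; torch::load
// fills the vector in the original save order.
std::vector<torch::Tensor> tensors;
torch::load(tensors, save_path + "data.pt");
torch::Tensor vector_data = tensors[0]; // matches _parameters['0'] in Python
torch::Tensor vector_mask = tensors[1]; // matches _parameters['1']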
Before the model can be called from C++, the PyTorch model has to be converted to a TorchScript model; the detailed steps are covered in the official PyTorch tutorial on loading a TorchScript model in C++.
C++ code that loads the TorchScript model and runs inference:
#include <chrono>
#include <fstream>
#include <iostream>
#include <vector>
#include "torch/script.h"
#include "torch/torch.h"
// Read an entire file into a byte buffer.
std::vector<char> get_the_bytes(std::string filename)
{
    std::ifstream input(filename, std::ios::binary);
    std::vector<char> bytes(
        (std::istreambuf_iterator<char>(input)),
        (std::istreambuf_iterator<char>()));
    input.close();
    return bytes;
}

// Load a tensor saved from Python with torch.save.
torch::Tensor GetTensor(const std::string &path)
{
    std::vector<char> f = get_the_bytes(path);
    torch::IValue x = torch::pickle_load(f);
    torch::Tensor my_tensor = x.toTensor();
    return my_tensor;
}
int main()
{
    torch::Device device(torch::kCPU);
    if (torch::cuda::is_available())
    {
        device = torch::Device(torch::kCUDA, 0);
    }

    // Load the saved inference test-case tensors
    torch::Tensor vector_data = GetTensor("test_cpp/dense_tnt_infer/data/vector_data.pt");
    torch::Tensor vector_mask = GetTensor("test_cpp/dense_tnt_infer/data/vector_mask.pt");
    torch::Tensor traj_polyline_mask = GetTensor("test_cpp/dense_tnt_infer/data/traj_polyline_mask.pt");
    torch::Tensor map_polyline_mask = GetTensor("test_cpp/dense_tnt_infer/data/map_polyline_mask.pt");
    torch::Tensor cent_point = GetTensor("test_cpp/dense_tnt_infer/data/cent_point.pt");

    torch::set_num_threads(1);
    std::vector<torch::jit::IValue> torch_inputs;
    torch_inputs.push_back(vector_data.to(device));
    torch_inputs.push_back(vector_mask.to(device));
    torch_inputs.push_back(traj_polyline_mask.to(device));
    torch_inputs.push_back(map_polyline_mask.to(device));
    torch_inputs.push_back(cent_point.to(device));

    // Load the TorchScript model onto the target device
    torch::jit::script::Module torch_script_model = torch::jit::load("models.densetnt.1/model_save/model.16_script.bin", device);

    // Run inference 100 times and report per-iteration latency
    for (int i = 0; i < 100; ++i)
    {
        auto t1 = std::chrono::high_resolution_clock::now();
        auto torch_output_tuple = torch_script_model.forward(torch_inputs);
        auto t2 = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double, std::milli> ms_double = t2 - t1;
        std::cout << ms_double.count() << "ms\n";
        // std::cout << torch_output_tuple << std::endl;
    }
    return 0;
}
Note that CUDA and cuDNN must already be installed for the CMake step below to succeed.
In the same directory as the C++ inference code, create a file named CMakeLists.txt with the following contents:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dense_tnt_infer) # named after the C++ source file that runs the model
set(CMAKE_PREFIX_PATH "/home/chenxin/libtorch") # path where LibTorch was unzipped
find_package(Torch REQUIRED)
add_executable(${PROJECT_NAME} "dense_tnt_infer.cc") # the C++ source file that runs the model
target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES})
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)
Then run the following commands in the directory containing CMakeLists.txt:
$ mkdir build
$ cd build
$ cmake ..
$ make
$ ./dense_tnt_infer
Output:
...
...
12.7836ms
12.7527ms
13.0666ms
14.3305ms
13.804ms
14.1567ms
13.3143ms
13.0827ms
13.0853ms
13.5594ms
The per-inference latency still has room for improvement through further code changes.
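As a first pass, a few generic LibTorch tweaks are worth trying; the sketch below (my assumption, not DenseTNT-specific tuning) disables autograd bookkeeping, switches the module to eval mode, and adds untimed warm-up runs:
// Common inference-latency tweaks, reusing torch_script_model and
// torch_inputs from the program above.
torch::NoGradGuard no_grad; // inference only: skip gradient tracking
torch_script_model.eval();  // disable dropout / use running batch-norm stats

// Warm-up: let the JIT profile and optimize, and CUDA kernels initialize,
// before any iteration is timed.
for (int i = 0; i < 10; ++i)
{
    torch_script_model.forward(torch_inputs);
}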