这几天刚刚把libtorch加载模型弄明白,记录一下。
1、正确安装 VS2017、OpenCV、CMake 以及 PyTorch 1.1
2、官网下载libtorch cpu 1.1版本(注意pytorch与libtorch版本一致)
3、pytorch 导出模型
import torch
from torchvision import models

# Dummy-input shape used for tracing: one 3-channel 224x224 image.
# It must match the input size the model will receive at inference time.
input_shape = (1, 3, 224, 224)

# Build the network (weights are random unless a trained checkpoint is loaded).
model = models.resnet18()
# To export an already-trained model, load its weights here, e.g.:
#state = torch.load('latest.pt')
#model.load_state_dict(state['model_state_dict'], strict=True)

# Switch to inference mode before tracing (affects BatchNorm/Dropout layers).
model.eval()

# Record the operations executed on the dummy input as a TorchScript module.
example = torch.rand(*input_shape)
traced_script_module = torch.jit.trace(model, example)

# Sanity check: the traced module is callable just like the original model.
output = traced_script_module(torch.ones(*input_shape))

# Serialize the TorchScript module so libtorch can load it from C++.
traced_script_module.save("model.pt")
4、编写 CMakeLists.txt
# Minimal libtorch + OpenCV inference example (builds example-app from example-app.cpp).
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(custom_ops)

# Both packages are located through CMAKE_PREFIX_PATH (see the configure command).
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

# Compiler flags exported by the Torch package config, as recommended by the
# libtorch docs; expands to nothing if this libtorch version does not set it.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app example-app.cpp)
# Scope OpenCV's include directories to the target instead of the whole directory.
target_include_directories(example-app PRIVATE ${OpenCV_INCLUDE_DIRS})
target_link_libraries(example-app ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
5、新建 build 文件夹并进入该目录,在其中打开命令行执行下面的命令。其中 "Visual Studio 15 Win64" 指 VS2017 的 64 位生成器。
cmake "-DCMAKE_PREFIX_PATH=D:\yourpath\opencv\build\x64\vc15\lib;D:\yourpath\libtorch" -DCMAKE_BUILD_TYPE=Release -G "Visual Studio 15 Win64" ..
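6、用 VS 打开 build 目录中生成的 .sln 工程,并把解决方案配置切换为 Release。Visual Studio 生成器是多配置生成器,会忽略上面的 CMAKE_BUILD_TYPE。而官网下载的 libtorch 是 Release 版,用 Debug 配置编译可能出现链接错误或运行时崩溃。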
7、编写libtorch代码加载模型
#include <cassert>
#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/script.h>
//#include <torch/torch.h>

using namespace std;
shared_ptr load_model(string model_path)
{
shared_ptr module = torch::jit::load(model_path);
//module->to(device);
assert(module != nullptr);
std::cout << "load model ok\n";
return module;
}
int main(int argc, const char* argv[])
{
if (argc != 3)
{
cerr << "usage : example-app ";
return -1;
}
shared_ptr module = load_model(argv[1]);
cv::Mat image = cv::imread(argv[2]);
cvtColor(image, image, cv::COLOR_BGR2RGB);
cv::Mat img_float;
image.convertTo(img_float, CV_32F, 1.0 / 255);
cv::resize(img_float, img_float, cv::Size(224, 224));
auto img_tensor = torch::from_blob(img_float.data, { 1, 224, 224, 3 });
img_tensor = img_tensor.permute({ 0, 3, 1, 2 });
//输入
std::vector inputs;
inputs.push_back(img_tensor);
// evalute time
double t = (double)cv::getTickCount();
auto out = module->forward(inputs).toTensor();
std::cout << out << std::endl;
t = (double)cv::getTickCount() - t;
printf("耗费时间为: %gs\n", t / cv::getTickFrequency());
inputs.pop_back();
return 0;
}
8、将 libtorch\lib 目录下的 c10.dll、caffe2.dll、torch.dll,以及 OpenCV 的 opencv_world400.dll,复制到与 exe 文件同级的目录。DLL 名称以实际下载的版本为准,opencv_world 后面的数字对应 OpenCV 版本号。
9、在命令行中运行生成的 exe。第一个参数为模型路径,第二个参数为图片路径,例如:example-app.exe model.pt test.jpg