这几天在学习用 C++ 部署深度学习模型，发现 libtorch 是目前人人都能上手的通用方案。
在 Linux 上安装库或程序有两种方式：一种是源码编译安装，一种是 apt 安装。这里为了固定版本，选择编译安装（OpenCV 和 CMake 的下载地址如下）：
https://github.com/opencv/opencv/tags?after=4.5.0
https://cmake.org/download/
1. 按照 PyTorch 官方示例，用 wget 下载 CPU 版本的 libtorch，解压即可使用，因为它是已经编译好的库文件。
https://pytorch.org/cppdocs/installing.html
2. 使用的代码是之前在知乎上看到的图像分类推理示例，模型文件为 resnet18.pt：
https://link.zhihu.com/?target=https%3A//github.com/BIGBALLON/PyTorch-CPP
https://blog.csdn.net/weixin_44523062/article/details/120132110
将下面的代码在 CLion 中打开，配置好 CMakeLists.txt，并把程序运行参数设为 resnet18.pt label.txt（resnet18.pt 可以用文末的 Python 脚本导出）。
编译生成可执行程序后，带上这两个参数运行就可以了。
// One-stop header.
#include <torch/script.h>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
// Preprocessing / reporting constants (typed constexpr instead of macros).
constexpr int kIMAGE_SIZE = 224;  // images are resized to kIMAGE_SIZE x kIMAGE_SIZE
constexpr int kCHANNELS = 3;      // colour channels per pixel fed to the model
constexpr int kTOP_K = 3;         // number of best-scoring classes to print
bool LoadImage(std::string file_name, cv::Mat &image) {
image = cv::imread(file_name); // CV_8UC3
if (image.empty() || !image.data) {
return false;
}
cv::namedWindow("d");
cv::imshow("d",image);
cv::waitKey();
cv::destroyAllWindows();
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
std::cout << "== image size: " << image.size() << " ==" << std::endl;
// scale image to fit
cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE);
cv::resize(image, image, scale);
std::cout << "== simply resize: " << image.size() << " ==" << std::endl;
// convert [unsigned int] to [float]
image.convertTo(image, CV_32FC3, 1.0f / 255.0f);
return true;
}
/// Reads a label file with one class name per line and appends every line to
/// `labels`. The line order is the order used when indexing by class id.
///
/// A trailing '\r' is stripped from each line so label files saved with
/// Windows (CRLF) line endings do not leak carriage returns into the output.
///
/// @param file_name  Path of the label file.
/// @param labels     Output vector; lines are appended, existing entries kept.
/// @return false if the file cannot be opened, true otherwise.
bool LoadImageNetLabel(const std::string &file_name,
                       std::vector<std::string> &labels) {
  std::ifstream ifs(file_name);
  if (!ifs) {
    return false;
  }
  std::string line;
  while (std::getline(ifs, line)) {
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    labels.push_back(line);
  }
  return true;
}
int main(int argc, const char *argv[]) {
if (argc != 3) {
std::cerr << "Usage: classifier "
""
<< std::endl;
return -1;
}
torch::jit::script::Module module = torch::jit::load(argv[1]);
std::cout << "== Switch to GPU mode" << std::endl;
// to GPU
// module.to(at::kCUDA);
module.to(at::kCPU);
std::cout << "== Model [" << argv[1] << "] loaded!\n";
std::vector<std::string> labels;
if (LoadImageNetLabel(argv[2], labels)) {
std::cout << "== Label loaded! Let's try it\n";
} else {
std::cerr << "Please check your label file path." << std::endl;
return -1;
}
std::string file_name = "";
cv::Mat image;
while (true) {
std::cout << "== Input image path: [enter Q to exit]" << std::endl;
std::cin >> file_name;
if (file_name == "Q") {
break;
}
if (LoadImage(file_name, image)) {
auto input_tensor = torch::from_blob(
image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS});
input_tensor = input_tensor.permute({0, 3, 1, 2});
input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);
// to GPU
// input_tensor = input_tensor.to(at::kCUDA);
input_tensor = input_tensor.to(at::kCPU);
torch::Tensor out_tensor = module.forward({input_tensor}).toTensor();
auto results = out_tensor.sort(-1, true);
auto softmaxs = std::get<0>(results)[0].softmax(0);
auto indexs = std::get<1>(results)[0];
for (int i = 0; i < kTOP_K; ++i) {
auto idx = indexs[i].item<int>();
std::cout << " ============= Top-" << i + 1
<< " =============" << std::endl;
std::cout << " Label: " << labels[idx] << std::endl;
std::cout << " With Probability: "
<< softmaxs[i].item<float>() * 100.0f << "%" << std::endl;
}
} else {
std::cout << "Can't load the image, please check your path." << std::endl;
}
}
return 0;
}
cmake_minimum_required(VERSION 3.20)
project(testlibtorch)

set(CMAKE_CXX_STANDARD 14)

# Location of the unpacked libtorch package. Stored as a cache entry so it can
# be overridden with -DTorch_DIR=... instead of editing this file; the default
# is the author's local path.
set(Torch_DIR /home2/libtorch_apk/libtorch/share/cmake/Torch
    CACHE PATH "Directory containing TorchConfig.cmake")

# Find OpenCV. No major version is pinned: the APIs used by prediction.cpp
# (imread, cvtColor/COLOR_BGR2RGB, resize, highgui windows) exist in both
# OpenCV 3.x and 4.x.
find_package(OpenCV REQUIRED)
find_package(Torch REQUIRED)

# Apply the compile flags exported by the Torch package.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(testlibtorch prediction.cpp)
target_include_directories(testlibtorch PRIVATE ${OpenCV_INCLUDE_DIRS})
target_link_libraries(testlibtorch "${TORCH_LIBRARIES}" ${OpenCV_LIBS})
set_property(TARGET testlibtorch PROPERTY CXX_STANDARD 14)
# Export an ImageNet-pretrained ResNet-18 to TorchScript ("resnet18.pt") so the
# C++ classifier can load it with torch::jit::load.
import torch
import torchvision

try:
    # torchvision >= 0.13: the explicit weights enum replaces the deprecated
    # `pretrained=True` flag (IMAGENET1K_V1 is the checkpoint it selected).
    from torchvision.models import ResNet18_Weights

    model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
except ImportError:
    # Older torchvision without the weights API.
    model = torchvision.models.resnet18(pretrained=True)

# Tracing records the ops executed on the example input, so switch to eval
# mode (inference behaviour for dropout / batch-norm) before tracing.
model.eval()

# The example input fixes the traced input layout: a 1x3x224x224 float batch.
example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("resnet18.pt")