一. 基本环境设置及测试
1.安装win版本cmake
2. 从官网下载Windows版的LibTorch(PyTorch C++库):https://pytorch.org/get-started/locally/
3.cmake 文件CMakeLists.txt
# Minimal LibTorch consumer: builds the example-app executable from example-app.cpp.
# CMake 4.x no longer accepts a minimum below 3.5, so the old VERSION 3.0 fails to configure.
cmake_minimum_required(VERSION 3.18...3.29)
project(example-app)

# Configure with -DCMAKE_PREFIX_PATH=<path-to-libtorch> so TorchConfig.cmake can be found.
find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app PRIVATE "${TORCH_LIBRARIES}")
# Recent LibTorch releases require C++17; C++11 no longer compiles their headers.
set_property(TARGET example-app PROPERTY CXX_STANDARD 17)

# On Windows the executable needs the LibTorch DLLs next to it at run time;
# copy them after each build instead of doing it by hand.
if(MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  # Skip the copy when nothing matched, otherwise copy_if_different gets no sources.
  if(TORCH_DLLS)
    add_custom_command(TARGET example-app POST_BUILD
      COMMAND "${CMAKE_COMMAND}" -E copy_if_different
              ${TORCH_DLLS}
              "$<TARGET_FILE_DIR:example-app>"
      VERBATIM)
  endif()
endif()
4.测试代码 example-app.cpp
#include
#include
int main() {
torch::Tensor tensor = torch::rand({2, 3});
std::cout << tensor << std::endl;
}
5. 在CMakeLists.txt和example-app.cpp所在的文件夹下创建build文件夹,并在build文件夹中运行:
cmake .. -G "Visual Studio 15 2017" -A x64 -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch
-DCMAKE_PREFIX_PATH必须设置为libtorch的正确路径;同时要用-A x64指定64位,因为CMake的Visual Studio 15 2017生成器默认生成32位工程。
6. 用VS打开生成的项目,将example-app设置为启动项目,然后生成exe文件。
7. 将/absolute/path/to/libtorch/lib下的所有dll文件复制到exe所在的文件夹,运行exe即可得到结果。
二.ResNet-18分类测试(opencv和pytorch一起使用)
0. 从OpenCV官网下载Windows下预编译好的版本即可。
1.使用python代码获得模型
import torch
import torchvision
# Export an ImageNet-pretrained ResNet-18 to TorchScript so LibTorch (C++) can load it.
# pretrained=True downloads the weights into the torch hub cache
# ($TORCH_HOME/hub/checkpoints, by default ~/.cache/torch/hub/checkpoints),
# NOT into the current directory.
# NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=torchvision.models.ResNet18_Weights.DEFAULT`.
model = torchvision.models.resnet18(pretrained=True)
# Evaluation mode: freezes batch-norm statistics (and disables dropout, if any)
# so the traced graph matches inference behaviour.
model.eval()
# An example input you would normally provide to your model's forward() method.
# Tracing records the ops executed for this input shape: 1 image, 3 channels, 224x224.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# Save the traced TorchScript module; the C++ side loads this file.
traced_script_module.save("resnet_model.pth")
导出自定义模型:
import torch
import torch.nn as nn
import torchvision
import medvision as mv
def resnet34(num_classes, pretrained=True):
    """Build a torchvision ResNet-34 with a new final fully-connected layer.

    Args:
        num_classes: number of output features of the replacement `fc` layer.
        pretrained: forwarded to `torchvision.models.resnet34`.

    Returns:
        The ResNet-34 model whose `fc` maps the original input width to
        `num_classes` outputs.
    """
    net = torchvision.models.resnet34(pretrained=pretrained)
    fc_width = net.fc.in_features
    net.fc = nn.Linear(fc_width, num_classes)
    return net
# Rebuild the trained architecture (ResNet-34 with a 2-output head) through the
# resnet34() helper above instead of duplicating its body here.
# pretrained=False: the trained weights come from the checkpoint loaded below.
CHECKPOINT_PATH = 'E:/pytorch_c++/build/fxfilter.pth'
model = resnet34(num_classes=2, pretrained=False)
# NOTE(review): load_checkpoint comes from the third-party medvision package;
# presumably it loads the checkpoint's weights into `model` in place — confirm.
mv.load_checkpoint(model, CHECKPOINT_PATH)
# Plain-PyTorch alternative when the file holds a bare state_dict:
# model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model.eval()
# Dummy input used only to record the graph: 1 image, 3 channels, 224x224.
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("resnet_model.pth")
2. C++推理(inference)代码
3.编写CMakeLists.txt
# LibTorch + OpenCV demo: builds the "test" executable from example-app.cpp.
cmake_minimum_required(VERSION 3.12...3.29)
project(test)

# REQUIRED already stops configuration with a descriptive error when a package
# is missing, so no separate *_FOUND check is needed afterwards.
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

message(STATUS "Pytorch status:")
message(STATUS " libraries: ${TORCH_LIBRARIES}")
message(STATUS "OpenCV library status:")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")

# NOTE(review): the target name "test" becomes reserved once enable_testing()/CTest
# is used (policy CMP0037); rename it if testing is ever added to this project.
add_executable(test example-app.cpp)
target_link_libraries(test PRIVATE ${TORCH_LIBRARIES} ${OpenCV_LIBS})
# Recent LibTorch releases require C++17; C++11 no longer compiles their headers.
set_property(TARGET test PROPERTY CXX_STANDARD 17)

# On Windows, copy the LibTorch DLLs next to the executable after each build.
# OpenCV's DLLs (in its bin directory) still need to be on PATH or copied manually.
if(MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  if(TORCH_DLLS)
    add_custom_command(TARGET test POST_BUILD
      COMMAND "${CMAKE_COMMAND}" -E copy_if_different
              ${TORCH_DLLS}
              "$<TARGET_FILE_DIR:test>"
      VERBATIM)
  endif()
endif()
4.cmake生成项目依赖
cmake -G "Visual Studio 15 2017" -A x64 -DCMAKE_PREFIX_PATH="D:/Program Files/opencv/build/x64/vc15/lib;E:/pytorch_c++/libtorch-win-shared-with-deps-latest/libtorch" ..
(1) CMake的Visual Studio 15 2017生成器默认生成32位工程,而libtorch是64位的,所以必须加上-A x64。
(2) -DCMAKE_PREFIX_PATH用于设置OpenCV和libtorch的路径,多个路径之间用分号分隔。如果OpenCV的绝对路径中含有空格(比如D:\Program Files\opencv\build\x64\vc15\lib中的Program Files),CMake会找不到这个路径,给路径加上双引号即可解决。
5. 使用VS生成项目。如果下载的libtorch是Release版本,就必须在Release模式下生成,否则在读入模型时会报错。
用CMake配置PyTorch与OpenCV(编译为共享库)
# Builds op.cpp into the shared library warp_perspective, linked against
# LibTorch and the OpenCV core/imgproc modules.
# cxx_std_* meta-features need CMake >= 3.8, and CMake 4.x rejects minimums
# below 3.5, so the old VERSION 3.1 is both too old and unsupported.
cmake_minimum_required(VERSION 3.18...3.29)
project(warp_perspective)

find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

# Define our library target
add_library(warp_perspective SHARED op.cpp)
# Recent LibTorch headers require C++17; the old cxx_range_for only implied C++11.
target_compile_features(warp_perspective PRIVATE cxx_std_17)
# Link against LibTorch and OpenCV in one keyword-form call (never mix signatures).
target_link_libraries(warp_perspective
  PRIVATE
    "${TORCH_LIBRARIES}"
    opencv_core
    opencv_imgproc
)