Win10+libtorch1.1+opencv 笔记

这几天刚刚把libtorch加载模型弄明白,记录一下。
1、正确安装VS2017+opencv+cmake +pytorch 1.1
2、官网下载libtorch cpu 1.1版本(注意pytorch与libtorch版本一致)
3、pytorch 导出模型

import torch
from torchvision import models
# Instantiate the network that will be exported; no checkpoint is loaded here.
model = models.resnet18()

# Optional: restore trained weights before tracing.
#state = torch.load('latest.pt')
#model.load_state_dict(state['model_state_dict'], strict=True)

# The tracing input must have the same shape the model will receive at inference.
dummy_input = torch.rand(1, 3, 224, 224)

# Switch to inference mode before tracing (eval() returns the module itself).
model.eval()

# Record the operations executed on the dummy input into a TorchScript module.
traced_script_module = torch.jit.trace(model, dummy_input)

# Smoke-run the traced module once on an all-ones batch of the same shape.
output = traced_script_module(torch.ones(1, 3, 224, 224))

# Serialize the traced module so it can be loaded from C++ with torch::jit::load.
traced_script_module.save("model.pt")

4、cmake 编写

# Build script for the libtorch + OpenCV example executable.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(custom_ops)

# Both packages are required; configuration stops if either one is not found.
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

# Single executable built from example-app.cpp.
add_executable(example-app example-app.cpp)

# OpenCV headers are only needed to compile this target.
target_include_directories(example-app PRIVATE ${OpenCV_INCLUDE_DIRS})

# Link against the libraries exported by the Torch and OpenCV packages.
target_link_libraries(example-app ${TORCH_LIBRARIES} ${OpenCV_LIBS})

# Compile the target as C++11.
set_target_properties(example-app PROPERTIES CXX_STANDARD 11)

5、新建build文件夹 并且进入build 打开命令行,这里的Visual Studio 15 Win64是指VS2017

cmake -DCMAKE_PREFIX_PATH="D:\yourpath\opencv\build\x64\vc15\lib;D:\yourpath\libtorch" -DCMAKE_BUILD_TYPE=Release -G"Visual Studio 15 Win64" ..

6、打开VS项目sln
7、编写libtorch代码加载模型

#include <torch/script.h>
//#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <cassert>
#include <cstdio>
#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;

shared_ptr load_model(string model_path)
{
    shared_ptr module = torch::jit::load(model_path);
    //module->to(device);
    assert(module != nullptr);
    std::cout << "load model ok\n";
    return module;
}




int main(int argc, const char* argv[])
{

    if (argc != 3) 
    {
        cerr << "usage : example-app  ";
        return -1;
    }

    shared_ptr module = load_model(argv[1]);
    
    cv::Mat image = cv::imread(argv[2]);
    
    cvtColor(image, image, cv::COLOR_BGR2RGB);
    cv::Mat img_float;
    image.convertTo(img_float, CV_32F, 1.0 / 255);
    cv::resize(img_float, img_float, cv::Size(224, 224));
    auto img_tensor = torch::from_blob(img_float.data, { 1, 224, 224, 3 });
    img_tensor = img_tensor.permute({ 0, 3, 1, 2 });

    //输入
    std::vector inputs;

    inputs.push_back(img_tensor);
      // evalute time
    double t = (double)cv::getTickCount();
    auto out = module->forward(inputs).toTensor();
    std::cout << out << std::endl;

    t = (double)cv::getTickCount() - t;
    printf("耗费时间为:  %gs\n", t / cv::getTickFrequency());
    inputs.pop_back();

    return 0;
}

8、将c10.dll、caffe2.dll、torch.dll、opencv_world400.dll放到与exe文件同级目录
9、命令行运行exe文件

你可能感兴趣的:(Win10+libtorch1.1+opencv 笔记)