C++:Windows平台下利用LibTorch调用PyTorch模型

文章目录

  • 环境
  • Libtorch下载
  • Pytorch将.pth转为.pt文件
    • python环境下的预测
      • 输出结果:rose
    • 新建pt模型生成文件
      • 输出结果:rose
  • C++调用pytorch模型
    • 新建空项目pt_alex
    • 项目属性配置
      • 修改配置管理器
      • 属性>VC++目录>包含目录
      • 属性>VC++目录>库目录
      • 属性>链接器>输入>附加依赖项
      • 注意CUDA下的情况
      • 属性>C/C++
    • 项目下新建test.cpp
      • 输出结果:rose
  • C# Demo
    • 新建C++空项目,封装DLL
      • 源码
      • 项目属性
      • 点击生成解决方案,生成DLL
    • 新建C#窗体应用
      • DLLFun.cs
      • 窗体Form2.cs核心代码
      • 结果

参考:C++调用PyTorch模型:LibTorch

环境

Windows10
VS2017
CPU

OpenCV3.0.0

Pytorch1.10.2  torchvision0.11.3
Libtorch1.10.2

Libtorch下载

Pytorch官网

C++:Windows平台下利用LibTorch调用PyTorch模型_第1张图片
解压后:注意红框文件夹路径,之后需要添加到项目属性配置中。
C++:Windows平台下利用LibTorch调用PyTorch模型_第2张图片

Pytorch将.pth转为.pt文件

所使用的模型为基于AlexNet的分类模型:AlexNet:论文阅读及pytorch网络搭建

python环境下的预测

输出结果:rose

C++:Windows平台下利用LibTorch调用PyTorch模型_第3张图片
C++:Windows平台下利用LibTorch调用PyTorch模型_第4张图片

新建pt模型生成文件

# tmp.py

import os
import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = AlexNet(num_classes=5).to(device)

    image = Image.open("rose2.jpg").convert('RGB')
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = data_transform(image)
    img = img.unsqueeze(dim=0)
    print(img.shape)

    # load model weights
    weights_path = "AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)

    testsize = 224

    if torch.cuda.is_available():
        modelState = torch.load(weights_path, map_location='cuda')
        model.load_state_dict(modelState, strict=False)
        model = model.cuda()
        model = model.eval()
        # An example input you would normally provide to your model's forward() method.
        example = torch.rand(1, 3, testsize, testsize)
        example = example.cuda()
        traced_script_module = torch.jit.trace(model, example)

        output = traced_script_module(img.cuda())
        print(output.shape)
        pred = torch.argmax(output, dim=1)
        print(pred)

        traced_script_module.save('model_cuda.pt')
    else:
        modelState = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(modelState, strict=False)
        example = torch.rand(1, 3, testsize, testsize)
        example = example.cpu()
        traced_script_module = torch.jit.trace(model, example)

        output = traced_script_module(img.cpu())
        print(output.shape)
        pred = torch.argmax(output, dim=1)
        print(pred)

        traced_script_module.save('model.pt')

if __name__ == '__main__':
    main()

输出结果:rose

C++:Windows平台下利用LibTorch调用PyTorch模型_第5张图片

C++:Windows平台下利用LibTorch调用PyTorch模型_第6张图片

C++调用pytorch模型

新建空项目pt_alex

C++:Windows平台下利用LibTorch调用PyTorch模型_第7张图片

项目属性配置

修改配置管理器

Release/x64
在这里插入图片描述

属性>VC++目录>包含目录

添加:(libtorch解压位置)
C++:Windows平台下利用LibTorch调用PyTorch模型_第8张图片

注意还应有opencv目录:(继承值修改可参考)
在这里插入图片描述

属性>VC++目录>库目录

添加:
C++:Windows平台下利用LibTorch调用PyTorch模型_第9张图片

属性>链接器>输入>附加依赖项

添加:
C++:Windows平台下利用LibTorch调用PyTorch模型_第10张图片

注意:
如果后续出现error:找不到c10.dll,
直接把该目录下的相应dll复制到项目pt_alex/x64/Release文件夹下。

注意还应有opencv目录:(Debug下为lib*d.lib)
在这里插入图片描述

注意CUDA下的情况

链接器>命令行,添加:

/INCLUDE:?warp_size@cuda@at@@YAHXZ

属性>C/C++

常规>SDL检查:选择否
语言>符合模式:选择否

项目下新建test.cpp

c++调用后分类结果不准确的参考:
Ptorch 与libTorch 使用过程中问题记录
注意python和c++中的图像预处理过程需要完全一致。

// test.cpp

#include  // One-stop header.
#include "torch/torch.h"
#include 
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include 
#include 
#include 
#include 
#include 
#include 

// class_list
/*
	"0": "daisy",
	"1": "dandelion",
	"2": "roses",
	"3": "sunflowers",
	"4": "tulips"
*/

std::string classList[5] = { "daisy", "dandelion", "rose", "sunflower", "tulip" };

std::string image_path = "rose2.jpg";

int main(int argc, const char* argv[]) {

	// Deserialize the ScriptModule from a file using torch::jit::load().
	//std::shared_ptr module = torch::jit::load("../../model_resnet_jit.pt");
	using torch::jit::script::Module;
	Module module = torch::jit::load("model.pt");

	std::cout << "测试图片:" << image_path << std::endl;

	std::cout << "cuda support:" << (torch::cuda::is_available() ? "ture" : "false") << std::endl;
	std::cout << "CUDNN:  " << torch::cuda::cudnn_is_available() << std::endl;
	std::cout << "GPU(s): " << torch::cuda::device_count() << std::endl;

	// module.to(at::kCUDA); //cpu下会在(auto image = cv::imread(image_path, cv::IMREAD_COLOR))行引起c10:error,未经处理的异常
	module.eval();
	module.to(at::kCPU);

	//assert(module != nullptr);
	//std::cout << "ok\n";

	//输入图像
	auto image = cv::imread(image_path, cv::IMREAD_COLOR);
	cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
	cv::Mat image_transfomed = cv::Mat(cv::Size(224, 224), image.type());
	cv::resize(image, image_transfomed, cv::Size(224, 224));

	//cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);

	// 转换为Tensor
	torch::Tensor tensor_image = torch::from_blob(image_transfomed.data,
		{ image_transfomed.rows, image_transfomed.cols,3 }, torch::kByte);
	tensor_image = tensor_image.permute({ 2,0,1 });
	tensor_image = tensor_image.toType(torch::kFloat);
	auto tensor_image_Tmp = torch::autograd::make_variable(tensor_image, false);
	tensor_image = tensor_image.div(255);
	tensor_image = tensor_image.unsqueeze(0);
	// tensor_image = tensor_image.to(at::kCUDA);
	tensor_image = tensor_image.to(at::kCPU);

	// 网络前向计算
	at::Tensor output = module.forward({ tensor_image }).toTensor();
	std::cout << "output:" << output << std::endl;

	auto prediction = output.argmax(1);
	std::cout << "prediction:" << prediction << std::endl;

	int maxk = 5;
	auto top3 = std::get<1>(output.topk(maxk, 1, true, true));

	std::cout << "top3: " << top3 << '\n';

	std::vector<int> res;
	for (auto i = 0; i < maxk; i++) {
		res.push_back(top3[0][i].item().toInt());
	}
	// for (auto i : res) {
	// 	std::cout << i << " ";
	// }
	// std::cout << "\n";

	int pre = torch::Tensor(prediction).item<int>();
	std::string result = classList[pre];
	std::cout << "This is:" << result << std::endl;

	cvWaitKey();

	return 0;
	// system("pause");
}

出现以下报错不影响项目生成:
C++:Windows平台下利用LibTorch调用PyTorch模型_第11张图片

输出结果:rose

C++:Windows平台下利用LibTorch调用PyTorch模型_第12张图片

C# Demo

新建C++空项目,封装DLL

  • 传入图像路径
  • 传出类别序号
  • // class_list
    /*
    “0”: “daisy”,
    “1”: “dandelion”,
    “2”: “roses”,
    “3”: “sunflowers”,
    “4”: “tulips”
    */

源码

// test.cpp

#include  // One-stop header.
#include "torch/torch.h"
#include 
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include 
#include 
#include 
#include 
#include 
#include 

#include "test.h"

std::string classList[5] = { "daisy", "dandelion", "rose", "sunflower", "tulip" };

int TestAlex(char* img)
{
	// Deserialize the ScriptModule from a file using torch::jit::load().
	//std::shared_ptr module = torch::jit::load("../../model_resnet_jit.pt");
	using torch::jit::script::Module;
	Module module = torch::jit::load("D:/model.pt");

	// ...... 略

	int pre = torch::Tensor(prediction).item<int>();
	std::string result = classList[pre];
	//std::cout << "This is:" << result << std::endl;

	return pre;
	// system("pause");
}
//test.h

#pragma once

#include 
#include 

extern "C" __declspec(dllexport) int TestAlex(char* img);

项目属性

  • 输出目录改为C#项目路径
    C++:Windows平台下利用LibTorch调用PyTorch模型_第13张图片

点击生成解决方案,生成DLL

新建C#窗体应用

DLLFun.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Runtime.InteropServices;

namespace AlexDemo
{
    class DllFun
    {
        public string img;

        [DllImport("AlexDLL.dll", CallingConvention = CallingConvention.Cdecl)]
        public extern static int TestAlex(string img); // 注意 C++ char* 对应 C# string
    }
}

窗体Form2.cs核心代码

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;
using System.IO;
using AlexDemo;

namespace Demo
{
    public partial class Form2 : Form
    {
        public Form2()
        {
            InitializeComponent();
            this.Load += new EventHandler(Form2_Load); //窗体启动后自动执行事件
        }

        private void Form2_Load(object sender, EventArgs e)
        {
            string[] classList = { "daisy", "dandelion", "rose", "sunflower", "tulip" };
            string fname = "path.txt";
            //StreamReader sr = new StreamReader(fname, Encoding.Default);
            StreamReader sr = new StreamReader(fname, Encoding.GetEncoding("gb2312"));
            string line = sr.ReadLine();
            //读取txt文件
            if (line != null)
            {
                this.pictureBox1.Image = Image.FromFile(line);
                if (line.Contains("\\"))
                {
                    line = line.Replace("\\", "/");
                }
            }
            int result;
            result = DllFun.TestAlex(line);
            //string r = result.ToString();
            label2.Text = classList[result];

            //StringBuilder img;
            //img = new StringBuilder(1024);
            //img.Append(line);
            //int r = DllFun.TestAlex(img);
            //label2.Text = "123";
        }

        private void label1_Click(object sender, EventArgs e)
        {

        }

    }
}

结果

C++:Windows平台下利用LibTorch调用PyTorch模型_第14张图片

你可能感兴趣的:(C/C++/C#,c++,libtorch)