Pytorch模型部署 - Libtorch(crnn模型部署)

Pytorch模型部署 - Libtorch

简介

libtorch是PyTorch官方(Facebook)提供的C++库,可以在C++程序中加载并运行导出的TorchScript模型,便于工业级部署和性能优化。

配置

  • cmake 3.0
  • libtorch-1.4(cpu)
  • opencv-4.1.1

安装:

libtorch+opencv联合编译,这里采用libtorch-1.4(cpu)+opencv-4.1.1。

  • 可能出现的问题

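    • libtorch与opencv联合编译项目时,报错 undefined reference to cv::imread(std::string const&, int)。原因:官方预编译的libtorch使用旧版C++ ABI(_GLIBCXX_USE_CXX11_ABI=0)构建,而本地编译的opencv默认使用新ABI,两者中std::string相关的符号名不同,链接器因此找不到该函数。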
    • 解决方案:
      • 在相同编译环境下,重新编译libtorch和opencv源码.(未测试…)
      • 在opencv的CMakeLists.txt中加上add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0),重新编译opencv,使其与libtorch的ABI保持一致。(测试通过)
  • libtorch安装:解压下载包就可以,在代码编译时指定库的路径即可。

  • opencv安装: 下载源码 https://opencv.org/releases/

    unzip opencv-4.1.1.zip
    cd opencv-4.1.1
    # If the cv::imread link error above appears, add the add_definitions line to CMakeLists.txt here, then rebuild and reinstall
    mkdir build && cd build
    cmake -D CMAKE_BUILD_TYPE=RELEASE -D OPENCV_GENERATE_PKGCONFIG=ON -D CMAKE_INSTALL_PREFIX=/usr/local ..
    make -j4
    sudo make install
    # Refresh the dynamic linker cache so the new libraries in /usr/local/lib are found at run time
    sudo ldconfig
    

    可通过 ls /usr/local/lib 查看安装好的opencv库。

案例:libtorch部署crnn-英文识别模型.

crnn: 文本识别模型,常用于OCR.

Step1: 模型转换

将pytorch训练好的crnn模型转换为libtorch能够读取的模型.

# covertion.py
# Convert a trained PyTorch CRNN checkpoint into a TorchScript file that
# libtorch can load with torch::jit::load.
from collections import OrderedDict

import torch

# NOTE(review): CRNN, keys and model_path come from the training project
# (see the crnn_libtorch repository) and must be imported/defined before
# this script runs, e.g.
#   from crnn import CRNN   # TODO confirm module path
#   import keys
#   model_path = "path/to/crnn.pth"

model = CRNN(32, 1, len(keys.alphabetEnglish) + 1, 256, 1).cpu()

# Load all tensors onto the CPU regardless of the device they were saved from.
state_dict = torch.load(model_path, map_location="cpu")
# Checkpoints saved from nn.DataParallel prefix every key with "module.";
# strip it so the keys match the bare model.
new_state_dict = OrderedDict(
    (k.replace('module.', ''), v) for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)
# Tracing records the operations actually executed, so BatchNorm/Dropout must
# be in inference mode before tracing.
model.eval()

# Convert the .pth weights into a traced .pt module; the example input is
# (batch, channel, height, width) = (1, 1, 32, 512).
example = torch.rand(1, 1, 32, 512)
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("src/crnn.pt")

代码过长,github附完整代码。github: crnn_libtorch

Step2: 模型部署

利用libtorch+opencv实现对文字条(单行文本图像)的识别。

//crnnDeploy.h
#include 
#include 
#include 
#include 

#include 
#include 
#include 

#ifndef CRNN_H
#define CRNN_H

class Crnn{
    public:
        Crnn(std::string& modelFile, std::string& keyFile);
        torch::Tensor loadImg(std::string& imgFile, bool isbath=false);
        void infer(torch::Tensor& input);
    private:
        torch::jit::script::Module m_module;
        std::vector m_keys;
        std::vector readKeys(const std::string& keyFile);
        torch::jit::script::Module loadModule(const std::string& modelFile);
};

#endif//CRNN_H
/*
@author
date: 2020-03-17
Introduce:
    Deploy crnn model with libtorch.
*/

#include "CrnnDeploy.h"

#include <sys/time.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/script.h>
#include <torch/torch.h>

// Constructor: load the TorchScript module first, then the character table.
Crnn::Crnn(std::string& modelFile, std::string& keyFile)
    : m_module(loadModule(modelFile)),
      m_keys(readKeys(keyFile)) {}


/// Read an image and convert it into a normalized float tensor for the CRNN.
///
/// The image is read as 8-bit grayscale, resized to height 32 with the width
/// scaled to keep the aspect ratio, and mapped from [0, 255] to [-1, 1] via
/// (x / 255 - 0.5) / 0.5.
///
/// @param imgFile  path of the image file
/// @param isbath   true: return shape {1, 32, W} (presumably one item to be
///                 combined into a batch by the caller); false: {1, 1, 32, W}
/// @return float tensor that owns its own storage
/// @throws std::runtime_error if imgFile cannot be read as an image
torch::Tensor Crnn::loadImg(std::string& imgFile, bool isbath){
    const cv::Mat input = cv::imread(imgFile, cv::IMREAD_GRAYSCALE);
    if(input.empty()){
        // Continuing would divide by input.rows == 0 below.
        throw std::runtime_error("Error: not image data, imgFile input wrong!! (" + imgFile + ")");
    }
    // Width at height 32 that keeps the aspect ratio; at least 1 pixel so
    // cv::resize never receives an empty target size.
    const int resizeW = std::max(1, input.cols * 32 / input.rows);
    cv::Mat resized;
    cv::resize(input, resized, cv::Size(resizeW, 32));

    // from_blob only borrows resized.data (a newly created, contiguous Mat).
    // toType(kFloat) below copies into tensor-owned memory, so the returned
    // tensor stays valid after `resized` is released.
    torch::Tensor imgTensor;
    if(isbath){
        imgTensor = torch::from_blob(resized.data, {32, resizeW, 1}, torch::kByte)
                        .permute({2, 0, 1});
    }else{
        imgTensor = torch::from_blob(resized.data, {1, 32, resizeW, 1}, torch::kByte)
                        .permute({0, 3, 1, 2});
    }
    imgTensor = imgTensor.toType(torch::kFloat);
    return imgTensor.div(255).sub(0.5).div(0.5);
}

// Greedy CTC decoding of one sequence of per-timestep class indices.
// Skips the blank class (index 0), collapses consecutive repeats and maps
// each remaining index to its entry in `keys`. Indices outside `keys` are
// skipped instead of being read out of bounds.
static std::string ctcGreedyDecode(const std::vector<int64_t>& indices,
                                   const std::vector<std::string>& keys){
    std::string text;
    for(std::size_t i = 0; i < indices.size(); ++i){
        const int64_t idx = indices[i];
        const bool isBlank = (idx == 0);
        const bool isRepeat = (i > 0 && indices[i - 1] == idx);
        const bool inRange = (idx > 0 && static_cast<std::size_t>(idx) < keys.size());
        if(!isBlank && !isRepeat && inRange){
            text += keys[static_cast<std::size_t>(idx)];
        }
    }
    return text;
}

/// Run the CRNN model on `input` and print the recognized text to stdout,
/// one line per image.
///
/// NOTE(review): the loop headers of the published version were lost in
/// HTML extraction. This body reconstructs the decoding its surviving
/// fragments show: argmax per timestep, skip index 0, merge repeated
/// indices, look up m_keys. Confirm against the original repository.
///
/// @param input  tensor produced by loadImg()
/// @throws std::runtime_error if the model output is not 3-D
void Crnn::infer(torch::Tensor& input){
    // Inference only: disable autograd bookkeeping.
    torch::NoGradGuard noGrad;
    torch::Tensor output = this->m_module.forward({input}).toTensor();
    // Output is indexed as (timestep, image, class); dim 1 is the image count.
    if(output.dim() != 3){
        throw std::runtime_error("Crnn::infer: expected a 3-D (T, N, C) model output");
    }
    // Best class per (timestep, image); argmax yields int64 indices.
    const torch::Tensor best = output.argmax(2).to(torch::kCPU);
    auto bestAcc = best.accessor<int64_t, 2>();
    const int64_t numSteps = best.size(0);
    const int64_t numImgs = best.size(1);

    std::vector<int64_t> indices;
    indices.reserve(static_cast<std::size_t>(numSteps));
    for(int64_t img = 0; img < numImgs; ++img){
        indices.clear();
        for(int64_t t = 0; t < numSteps; ++t){
            indices.push_back(bestAcc[t][img]);
        }
        std::cout << ctcGreedyDecode(indices, this->m_keys) << std::endl;
    }
}

// Split a UTF-8 byte string into one std::string per encoded character.
// The sequence length comes from the lead byte (1-4 bytes). Bytes that are
// not valid lead bytes (continuation bytes, 0xF8-0xFF) become single-byte
// entries. A truncated trailing sequence is kept as-is.
static std::vector<std::string> splitUtf8Chars(const std::string& text){
    std::vector<std::string> chars;
    std::size_t i = 0;
    while(i < text.size()){
        // Inspect the lead byte as unsigned so masking is independent of char signedness.
        const unsigned char lead = static_cast<unsigned char>(text[i]);
        std::size_t len = 1;
        if((lead & 0xE0) == 0xC0){
            len = 2;
        }else if((lead & 0xF0) == 0xE0){
            len = 3;
        }else if((lead & 0xF8) == 0xF0){
            len = 4;
        }
        chars.push_back(text.substr(i, len));
        i += len;
    }
    return chars;
}

/// Read the character table from keyFile.
///
/// @param keyFile  path of a UTF-8 text file of characters
/// @return {" ", c1, c2, ...}: a space placeholder at index 0 followed by
///         every UTF-8 character of the file in order
/// @throws std::runtime_error if keyFile cannot be opened
std::vector<std::string> Crnn::readKeys(const std::string& keyFile){
    std::ifstream in(keyFile);
    if(!in){
        // Without a table, infer() would index m_keys out of range.
        throw std::runtime_error("Error: cannot open key file: " + keyFile);
    }
    std::ostringstream contents;
    contents << in.rdbuf();

    std::vector<std::string> words;
    // Original note: the leading space gets filtered out, so add it back here.
    // NOTE(review): index 0 is also the index skipped by the decoder in
    // infer(), so this entry appears to shift file characters to indices 1..N.
    words.push_back(" ");
    const std::vector<std::string> chars = splitUtf8Chars(contents.str());
    words.insert(words.end(), chars.begin(), chars.end());
    return words;
}

/// Load a TorchScript module from modelFile.
///
/// On failure the reason is written to stderr and the c10::Error is
/// rethrown. Returning a default-constructed Module instead would only
/// defer the failure to the first forward() call.
///
/// @param modelFile  path of a TorchScript (.pt) file
/// @return the loaded module
/// @throws c10::Error if the file cannot be loaded
torch::jit::script::Module Crnn::loadModule(const std::string& modelFile){
    try{
        return torch::jit::load(modelFile);
    }catch(const c10::Error& e){
        std::cerr << "error loadding the model !!!\n"
                  << "  file: " << modelFile << "\n"
                  << "  reason: " << e.what() << "\n";
        throw;
    }
}


/// Current wall-clock time in whole milliseconds since the Unix epoch
/// (sub-millisecond part truncated). main() uses the difference of two calls.
long getCurrentTime(void){
    timeval now{};
    gettimeofday(&now, nullptr);
    const auto millisFromSeconds = now.tv_sec * 1000;
    return millisFromSeconds + now.tv_usec / 1000;
}

/// Command-line entry point: CrnnDeploy <model.pt> <keys.txt> <image>
///
/// Loads the model and character table, recognizes the text in the image,
/// prints it, then prints the total elapsed time (model loading included).
/// Returns -1 on missing arguments or on any exception.
int main(int argc, const char* argv[]){
    if(argc < 4){
        printf("Error use CrnnDeploy: loss input param !!! \n");
        printf("Usage: CrnnDeploy <model.pt> <keys.txt> <image>\n");
        return -1;
    }
    std::string modelFile = argv[1];
    std::string keyFile = argv[2];
    std::string imgFile = argv[3];

    try{
        const long t1 = getCurrentTime();
        {
            // Automatic storage instead of new/delete: released even if
            // loading or inference throws.
            Crnn crnn(modelFile, keyFile);
            torch::Tensor input = crnn.loadImg(imgFile);
            crnn.infer(input);
        }
        const long t2 = getCurrentTime();
        printf("ocr time : %ld ms \n", (t2 - t1));
    }catch(const std::exception& e){
        std::cerr << "CrnnDeploy failed: " << e.what() << "\n";
        return -1;
    }
    return 0;
}

完整代码和测试模型:
github: crnn_libtorch

获取代码: git clone https://github.com/chenyangMl/crnn_libtorch.git

参考

  • opencv installation: https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html
  • libtorch : https://pytorch.org/tutorials/advanced/cpp_frontend.html

你可能感兴趣的:(libtorch)