本想自己从头写起,查了一下有人实现过,那我就只剩验证和改善的工作了。
参考博客:Pytorch模型部署 - Libtorch(crnn模型部署)
Step1: 模型转换
将pytorch训练好的crnn模型转换为libtorch能够读取的模型.
#covertion.py
# Convert a trained PyTorch CRNN checkpoint (.pth) into a TorchScript file (.pt)
# that LibTorch can load.
# NOTE(review): CRNN, keys and model_path are not defined in this excerpt; they
# presumably come from the complete script in the linked GitHub repository.
from collections import OrderedDict

import torch
import torchvision

model = CRNN(32, 1, len(keys.alphabetEnglish) + 1, 256, 1).cpu()
state_dict = torch.load(
    model_path, map_location=lambda storage, loc: storage)

# Checkpoints saved from a DataParallel-wrapped model prefix every key with
# `module.`; strip it so the keys match the plain model.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k.replace('module.', '')  # remove `module.`
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict)

# Tracing records the operations exactly as they run, so switch to inference
# mode first; otherwise train-mode behaviour of layers such as BatchNorm/Dropout
# (if the model has any) would be baked into the exported graph.
model.eval()

# convert pth-model to pt-model
example = torch.rand(1, 1, 32, 512)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("src/crnn.pt")
代码过长,github附完整代码。github: crnn_libtorch
Step2: 模型部署
利用libtorch+opencv实现对文字条的识别.
//crnnDeploy.h
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/script.h>
#include <torch/torch.h>
#ifndef CRNN_H
#define CRNN_H
/// Text-line recognizer that runs a TorchScript CRNN model through LibTorch.
///
/// Usage: construct with the model file and the key (alphabet) file, turn an
/// image into a tensor with loadImg(), then call infer() to print the
/// recognized text.
class Crnn{
public:
    /// @param modelFile path of the TorchScript model passed to torch::jit::load.
    /// @param keyFile   path of the UTF-8 text file holding the recognizable characters.
    Crnn(std::string& modelFile, std::string& keyFile);

    /// Reads an image as grayscale, resizes it to height 32 and normalizes it.
    /// @param imgFile path of the image to read.
    /// @param isbath  when true the tensor has no leading batch dimension
    ///                (C,H,W); otherwise it is shaped (1,C,H,W).
    torch::Tensor loadImg(std::string& imgFile, bool isbath=false);

    /// Runs the model on @p input and prints the decoded text to stdout.
    void infer(torch::Tensor& input);
private:
    torch::jit::script::Module m_module;   ///< loaded TorchScript model
    std::vector<std::string> m_keys;       ///< class index -> character (UTF-8)

    /// Splits the content of @p keyFile into one string per UTF-8 character.
    std::vector<std::string> readKeys(const std::string& keyFile);

    /// Loads the TorchScript module stored in @p modelFile.
    torch::jit::script::Module loadModule(const std::string& modelFile);
};
#endif//CRNN_H
/*
@author
date: 2020-03-17
Introduce:
Deploy crnn model with libtorch.
*/
#include "CrnnDeploy.h"
#include
#include
// Constructor: loads the TorchScript module first, then the key table,
// mirroring the declaration order of the two members.
Crnn::Crnn(std::string& modelFile, std::string& keyFile)
    : m_module(loadModule(modelFile)),
      m_keys(readKeys(keyFile)){
}
// Reads imgFile as a single-channel image, scales it to height 32 (keeping the
// aspect ratio) and converts it to a float tensor normalized to [-1, 1].
//
// imgFile: path of the image to read.
// isbath : true  -> tensor shaped (1, 32, W)   (no batch dimension)
//          false -> tensor shaped (1, 1, 32, W)
// Throws std::runtime_error when the image cannot be read; previously the
// function only printed a message and then divided by a zero row count.
torch::Tensor Crnn::loadImg(std::string& imgFile, bool isbath){
    cv::Mat input = cv::imread(imgFile, 0);   // flag 0: load as grayscale
    if(input.empty()){
        throw std::runtime_error(
            "Error: not image data, imgFile input wrong!! (" + imgFile + ")");
    }

    // Target width that preserves the aspect ratio at height 32. Integer
    // division can truncate to 0 for very tall, narrow images, which
    // cv::resize rejects, so keep at least one column.
    int resizedWidth = input.cols * 32 / input.rows;
    if(resizedWidth < 1){
        resizedWidth = 1;
    }
    cv::resize(input, input, cv::Size(resizedWidth, 32));

    // from_blob only wraps the cv::Mat buffer (no copy); the layout is H x W x C
    // with a single channel, so permute to channel-first.
    torch::Tensor imgTensor;
    if(isbath){
        imgTensor = torch::from_blob(input.data, {32, resizedWidth, 1}, torch::kByte);
        imgTensor = imgTensor.permute({2, 0, 1});
    }else{
        imgTensor = torch::from_blob(input.data, {1, 32, resizedWidth, 1}, torch::kByte);
        imgTensor = imgTensor.permute({0, 3, 1, 2});
    }

    // Converting uint8 -> float allocates fresh storage, so the returned tensor
    // no longer refers to `input`'s buffer once this function returns.
    imgTensor = imgTensor.toType(torch::kFloat);

    // (x / 255 - 0.5) / 0.5 maps [0, 255] to [-1, 1].
    imgTensor = imgTensor.div(255);
    imgTensor = imgTensor.sub(0.5);
    imgTensor = imgTensor.div(0.5);
    return imgTensor;
}
// Runs the model on `input`, greedily decodes the prediction of every image
// and prints one line of text per image to stdout.
//
// NOTE(review): the original loop bodies were truncated in this listing; this
// reconstruction assumes the model output is shaped (sequence, batch, classes),
// which matches the surviving `numImgs = output.sizes()[1]` — confirm against
// the model. Decoding rule (same for one image or a batch): at each step take
// the highest-scoring class, skip class 0, and skip a class equal to the one
// at the previous step.
void Crnn::infer(torch::Tensor& input){
    torch::Tensor output = this->m_module.forward({input}).toTensor();
    if(output.dim() != 3){
        throw std::runtime_error("Crnn::infer: expected a 3-D model output");
    }

    // Best class per (step, image); moved to CPU so it can be indexed directly.
    torch::Tensor best = output.argmax(2).to(torch::kCPU).contiguous();
    const int64_t seqLen = best.size(0);
    const int64_t numImgs = best.size(1);
    const int64_t numKeys = static_cast<int64_t>(this->m_keys.size());
    auto classAt = best.accessor<int64_t, 2>();

    for(int64_t img = 0; img < numImgs; ++img){
        std::string text;
        for(int64_t step = 0; step < seqLen; ++step){
            const int64_t cur = classAt[step][img];
            const bool repeated = step > 0 && classAt[step - 1][img] == cur;
            if(cur == 0 || repeated){
                continue;
            }
            // Guard against a key file shorter than the model's class count
            // instead of indexing past the end of m_keys.
            if(cur < numKeys){
                text += this->m_keys[static_cast<std::size_t>(cur)];
            }
        }
        std::cout << text << std::endl;
    }
}
// Reads keyFile and splits its content into one string per UTF-8 character.
//
// The returned table starts with " " at index 0 (original note: the first
// space had been filtered out, so it is added back here), followed by the
// characters of the file in order. Every byte of the file is consumed, line
// breaks included.
// If the file cannot be opened, a message is written to std::cerr and only the
// leading " " entry is returned.
std::vector<std::string> Crnn::readKeys(const std::string& keyFile){
    std::vector<std::string> words;
    words.push_back(" ");

    std::ifstream in(keyFile);
    if(!in){
        std::cerr << "error opening the key file: " << keyFile << "\n";
        return words;
    }
    std::ostringstream tmp;
    tmp << in.rdbuf();
    const std::string keys = tmp.str();

    const std::size_t len = keys.size();
    std::size_t i = 0;
    while(i < len){
        const unsigned char lead = static_cast<unsigned char>(keys[i]);
        // 0xF8..0xFF can never start a UTF-8 sequence.
        assert((lead & 0xF8) <= 0xF0);

        // The lead byte's high bits give the sequence length; ASCII and stray
        // continuation bytes are taken as a single byte.
        std::size_t next = 1;
        if((lead & 0xE0) == 0xC0){
            next = 2;
        }else if((lead & 0xF0) == 0xE0){
            next = 3;
        }else if((lead & 0xF8) == 0xF0){
            next = 4;
        }
        // substr clamps at the end of the string, so a truncated final
        // sequence yields a shorter entry.
        words.push_back(keys.substr(i, next));
        i += next;
    }
    return words;
}
// Loads the TorchScript module stored in modelFile.
// When torch::jit::load throws c10::Error, a message is written to std::cerr
// and a default-constructed module is returned instead.
torch::jit::script::Module Crnn::loadModule(const std::string& modelFile){
    try{
        return torch::jit::load(modelFile);
    }catch(const c10::Error&){
        std::cerr << "error loadding the model !!!\n";
    }
    torch::jit::script::Module emptyModule;
    return emptyModule;
}
// Returns the current wall-clock time in milliseconds since the system
// clock's epoch. Uses std::chrono rather than the POSIX-only gettimeofday so
// the helper also builds on non-POSIX toolchains.
long getCurrentTime(void){
    const auto sinceEpoch = std::chrono::system_clock::now().time_since_epoch();
    const auto millis =
        std::chrono::duration_cast<std::chrono::milliseconds>(sinceEpoch);
    return static_cast<long>(millis.count());
}
// Entry point: CrnnDeploy <modelFile> <keyFile> <imgFile>
// Loads the model and key table, recognizes the text in the image, prints it,
// and reports the elapsed time in milliseconds.
// Returns 0 on success and -1 on a usage error or when recognition fails.
int main(int argc, const char* argv[]){
    if(argc<4){
        printf("Error use CrnnDeploy: loss input param !!! \n");
        return -1;
    }
    std::string modelFile = argv[1];
    std::string keyFile = argv[2];
    std::string imgFile = argv[3];

    long t1 = getCurrentTime();
    try{
        // Automatic storage instead of new/delete: the recognizer is released
        // even if loading or inference throws. The inner scope keeps its
        // destruction inside the timed region, as before.
        Crnn crnn(modelFile, keyFile);
        torch::Tensor input = crnn.loadImg(imgFile);
        crnn.infer(input);
    }catch(const std::exception& e){
        // Report the failure and exit with an error code rather than letting
        // the exception escape main.
        fprintf(stderr, "Error: %s\n", e.what());
        return -1;
    }
    long t2 = getCurrentTime();

    printf("ocr time : %ld ms \n", (t2-t1));
    return 0;
}
验证结果(图片与识别结果):