Pytorch crnn 笔记(三)

本想自己从头写起,查了一下有人实现过,那我就只剩验证和改善的工作了。

参考博客:Pytorch模型部署 - Libtorch(crnn模型部署)

Step1: 模型转换

将pytorch训练好的crnn模型转换为libtorch能够读取的模型.

#covertion.py
from collections import OrderedDict

import torch
import torchvision

# NOTE: `CRNN`, `keys` and `model_path` are not defined in this excerpt; they
# come from the full project referenced below (crnn_libtorch on GitHub).
model = CRNN(32, 1, len(keys.alphabetEnglish) + 1, 256, 1).cpu()

# Load the checkpoint onto the CPU regardless of the device it was saved from.
state_dict = torch.load(
    model_path, map_location=lambda storage, loc: storage)

# Strip a leading `module.` from every parameter name (presumably added by
# nn.DataParallel at training time -- confirm against the training script).
# Only the prefix is removed so a `module.` appearing deeper in a key is kept.
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

# A traced module keeps the mode it was traced in, so switch to inference mode
# first; otherwise train-time layer behaviour (dropout / batch-norm statistics,
# if the model has such layers) would be baked into the exported graph.
model.eval()

# convert pth-model to pt-model
# Example input: one single-channel image of height 32 and width 512.
example = torch.rand(1, 1, 32, 512)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("src/crnn.pt")

代码过长,github附完整代码。github: crnn_libtorch

Step2: 模型部署

利用libtorch+opencv实现对文字条的识别.

//CrnnDeploy.h
#include 
#include 
#include 
#include 

#include 
#include 
#include 

#ifndef CRNN_H
#define CRNN_H

/// Wrapper around a TorchScript CRNN model plus the character table used to
/// turn predicted class indices back into text.
class Crnn{
    public:
        /// Loads the TorchScript module from `modelFile` and the character
        /// table from `keyFile`.
        Crnn(std::string& modelFile, std::string& keyFile);

        /// Reads `imgFile` as a grayscale image, resizes it to height 32 and
        /// returns it as a float tensor normalized with (x / 255 - 0.5) / 0.5.
        /// With `isbath == false` (default) the shape is {1, 1, 32, W};
        /// with `isbath == true` the shape is {1, 32, W} (no leading
        /// dimension).
        torch::Tensor loadImg(std::string& imgFile, bool isbath=false);

        /// Runs the model on `input` and prints the decoded text to stdout,
        /// one line per image.
        void infer(torch::Tensor& input);
    private:
        torch::jit::script::Module m_module;  ///< loaded TorchScript model
        std::vector<std::string> m_keys;      ///< class index -> character

        /// Reads `keyFile` and splits it into one string per UTF-8 character.
        std::vector<std::string> readKeys(const std::string& keyFile);

        /// Loads the TorchScript module stored in `modelFile`.
        torch::jit::script::Module loadModule(const std::string& modelFile);
};

#endif//CRNN_H
/*
@author
date: 2020-03-17
Introduce:
    Deploy crnn model with libtorch.
*/

#include "CrnnDeploy.h"
#include 
#include 

// Constructor: loads the TorchScript module first, then the character table.
Crnn::Crnn(std::string& modelFile, std::string& keyFile)
    : m_module(loadModule(modelFile)),
      m_keys(readKeys(keyFile)){
}


/**
 * Load an image file as a normalized grayscale float tensor.
 *
 * The image is read as a single-channel 8-bit image, resized to a height of
 * 32 pixels (width scaled by the same factor, at least 1 pixel), converted to
 * float and normalized with (x / 255 - 0.5) / 0.5.
 *
 * @param imgFile path of the image to read.
 * @param isbath  false: the result has shape {1, 1, 32, W};
 *                true:  the result has shape {1, 32, W} (no leading dimension).
 * @return the normalized tensor, or an undefined tensor (`!t.defined()`) when
 *         the image could not be read.
 */
torch::Tensor Crnn::loadImg(std::string& imgFile, bool isbath){
    // Flag 0 asks OpenCV for a grayscale (single-channel) image.
    cv::Mat input = cv::imread(imgFile, 0);
    if(!input.data){
        printf("Error: not image data, imgFile input wrong!!");
        // Bail out here: continuing would divide by input.rows == 0 below.
        return torch::Tensor();
    }

    // Keep the aspect ratio while scaling to height 32. Integer division can
    // yield 0 for very tall, narrow images, which cv::resize rejects.
    int resizedWidth = input.cols * 32 / input.rows;
    if(resizedWidth < 1){
        resizedWidth = 1;
    }
    cv::resize(input, input, cv::Size(resizedWidth, 32));

    // from_blob only wraps the cv::Mat buffer without copying it; the
    // conversion to float below produces a tensor with its own storage, so the
    // returned tensor does not refer to `input` once this function returns.
    torch::Tensor imgTensor;
    if(isbath){
        imgTensor = torch::from_blob(input.data, {32, resizedWidth, 1}, torch::kByte);
        imgTensor = imgTensor.permute({2,0,1});       // HWC -> CHW
    }else{
        imgTensor = torch::from_blob(input.data, {1, 32, resizedWidth, 1}, torch::kByte);
        imgTensor = imgTensor.permute({0,3,1,2});     // NHWC -> NCHW
    }

    // Map [0, 255] to [-1, 1].
    imgTensor = imgTensor.toType(torch::kFloat);
    imgTensor = imgTensor.div(255);
    imgTensor = imgTensor.sub(0.5);
    imgTensor = imgTensor.div(0.5);
    return imgTensor;
}

/**
 * Run the model on `input` and print the recognized text to stdout, one line
 * per image.
 *
 * The model output is indexed as [time step][image][class]. Decoding is
 * greedy: for every time step the highest-scoring class is taken, class 0 is
 * skipped, and a class equal to the one predicted at the previous time step
 * is emitted only once. Each remaining class index is mapped to its character
 * through m_keys.
 *
 * NOTE(review): the loop headers of the original were lost when the page was
 * extracted; this body is reconstructed from the surviving statements and
 * brace structure -- confirm against the upstream repository.
 *
 * @param input image tensor as produced by loadImg(); an undefined tensor is
 *              rejected with an error message.
 */
void Crnn::infer(torch::Tensor& input){
    if(!input.defined()){
        std::cerr << "Error: infer() called with an undefined input tensor\n";
        return;
    }

    torch::Tensor output = this->m_module.forward({input}).toTensor();
    if(output.dim() != 3){
        std::cerr << "Error: expected a 3-D model output, got "
                  << output.dim() << " dimension(s)\n";
        return;
    }

    const int64_t numSteps = output.size(0);
    const int64_t numImgs = output.size(1);

    // Best class per (time step, image), computed once for the whole output.
    const torch::Tensor best = output.argmax(2);

    for(int64_t img = 0; img < numImgs; ++img){
        // Character transcription for one image.
        std::string text;
        int64_t prevIdx = -1;  // no previous prediction at the first step
        for(int64_t step = 0; step < numSteps; ++step){
            const int64_t idx = best[step][img].item<int64_t>();
            const bool repeated = (idx == prevIdx);
            prevIdx = idx;
            if(idx == 0 || repeated){
                continue;
            }
            // Skip indices the character table does not cover instead of
            // reading past the end of m_keys.
            if(static_cast<std::size_t>(idx) < this->m_keys.size()){
                text += this->m_keys[static_cast<std::size_t>(idx)];
            }
        }
        std::cout << text << std::endl;
    }
}

/**
 * Read the character table from `keyFile`.
 *
 * The whole file is read and split into one string per UTF-8 encoded
 * character (1 to 4 bytes, chosen from the lead byte). Entry 0 of the result
 * is always a single space; the characters of the file follow from index 1.
 *
 * @param keyFile path of the UTF-8 text file holding the characters.
 * @return the table; only the leading " " entry when the file cannot be
 *         opened (an error is written to stderr in that case).
 */
std::vector<std::string> Crnn::readKeys(const std::string& keyFile){
    std::vector<std::string> words;
    // Original author's note: the first space gets filtered out, so it is
    // added back here.
    words.push_back(" ");

    std::ifstream in(keyFile);
    if(!in){
        std::cerr << "error opening the key file: " << keyFile << "\n";
        return words;
    }
    std::ostringstream tmp;
    tmp << in.rdbuf();
    const std::string keys = tmp.str();

    const std::size_t len = keys.length();
    std::size_t i = 0;
    while (i < len) {
        // Work on an unsigned byte so the masks below do not depend on
        // whether plain `char` is signed.
        const unsigned char lead = static_cast<unsigned char>(keys[i]);
        assert ((lead & 0xF8) <= 0xF0);

        // Sequence length from the lead byte; anything else (ASCII, or a
        // stray continuation byte) is taken as a single byte.
        std::size_t next = 1;
        if ((lead & 0xE0) == 0xC0) {
            next = 2;
        } else if ((lead & 0xF0) == 0xE0) {
            next = 3;
        } else if ((lead & 0xF8) == 0xF0) {
            next = 4;
        }
        // substr clamps to the end of the string if a sequence is truncated.
        words.push_back(keys.substr(i, next));
        i += next;
    }
    return words;
}

/**
 * Load the TorchScript module stored in `modelFile`.
 *
 * A c10::Error raised while loading is reported on stderr together with the
 * file name and the error text; it is not propagated.
 *
 * @param modelFile path of the serialized TorchScript module.
 * @return the loaded module, or a default-constructed module when loading
 *         failed.
 */
torch::jit::script::Module Crnn::loadModule(const std::string& modelFile){
    torch::jit::script::Module module;
    try{
        module = torch::jit::load(modelFile);
    }catch(const c10::Error& e){
        std::cerr << "error loadding the model !!!\n";
        // Keep the cause visible; without it a bad path and a corrupt file
        // look the same to the user.
        std::cerr << "  file: " << modelFile << "\n"
                  << "  reason: " << e.what() << "\n";
    }
    return module;
}


// Returns the time reported by gettimeofday(), converted to milliseconds.
long getCurrentTime(void){
    timeval now;
    gettimeofday(&now, nullptr);
    // Seconds scaled up to ms plus microseconds scaled down to ms.
    const auto millis = now.tv_sec * 1000 + now.tv_usec / 1000;
    return millis;
}

/**
 * Command-line entry point.
 *
 * Expects three arguments: argv[1] model file, argv[2] key (character table)
 * file, argv[3] image file. Loads the model, recognizes the image, and prints
 * the elapsed time in milliseconds.
 *
 * @return 0 on success, -1 when fewer than three arguments are given.
 */
int main(int argc, const char* argv[]){

    if(argc<4){
        printf("Error use CrnnDeploy: loss input param !!! \n");
        return -1;
    }
    std::string modelFile = argv[1];
    std::string keyFile = argv[2];
    std::string imgFile = argv[3];

    const long t1 = getCurrentTime();
    {
        // Automatic storage instead of new/delete: the object is released
        // even if loadImg() or infer() throws. The scope ends before t2 so
        // tear-down stays inside the measured interval, as before.
        Crnn crnn(modelFile, keyFile);
        torch::Tensor input = crnn.loadImg(imgFile);
        crnn.infer(input);
    }
    const long t2 = getCurrentTime();

    printf("ocr time : %ld ms \n", (t2-t1));
    return 0;
}

验证结果(图片与识别结果):

Pytorch crnn 笔记(三)_第1张图片

 

 

 

 

 

 

你可能感兴趣的:(文字识别,pytorch,C++)