CRNN-pytorch模型转libtorch模型踩坑记录

这段时间一直在做CRNN文字识别的问题,从pytorch中训练好的模型然后转到libtorch中去,但是CRNN提供的代码没有转libtorch模型的部分,于是就在网上到处乱找,其中找到了这篇转的代码crnn模型转换。感兴趣的可以进去看一下,这个大哥也是从git上找的别人,可能他的训练代码的网络和转换代码的网络是一样的结构,所以它可以转换成功。
我在转换的时候首先是报不能转换成功,然后修改了部分代码之后,能够转换成功,但是在libtorch上推理的时候为啥一点准确度都没有呢,哪怕准确一点就可以啊,然后就开始各种查,检测每个变量时候和pytorch的变量是否一样,不管是图片的预处理还是从网络输出的张量形状,最后还是不行,就寻思是不是转换的时候出了问题,然后就在他的基础上进行了修改,最后可以在自己的网络结构上转换成功;

import numpy as np
import time
import cv2
import torch
from torch.autograd import Variable
import lib.utils.utils as utils
import lib.models.crnn as crnn
import lib.config.alphabets as alphabets
import yaml
from easydict import EasyDict as edict
import os
import random


def deal_cfg(path_cfg):
    with open(path_cfg, 'r') as f:
        config = yaml.load(f)
        config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    print("len(config.DATASET.ALPHABETS):",len(config.DATASET.ALPHABETS))
    return config

if __name__ == '__main__':

    path_cfg = "./lib/config/OWN_config.yaml"
    path_pth = "/root/CRNN_Chinese_Characters_Rec-stable/output/OWN/crnn/acc_0.9766.pth"


    print("cuda?", torch.cuda.is_available())
    config = deal_cfg(path_cfg)
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(path_pth))
    checkpoint = torch.load(path_pth)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()

    img = torch.ones(1,1,32,320).cuda()

    img = img.type(torch.FloatTensor)

    img = img.to(device)
   


    ##################################################################
    traced_script_module = torch.jit.trace(model, img)
    traced_script_module.save("./acc.9833.pt")
    ##############################################################

这样转成功了,我的网络结构和上面那位大哥提供连接的结构不一样,所以我猜是结构的原因,导致推理不准确,
另外还是有一点是我在训练的时候添加了dropout层,在转换的时候我把dropout层给注销掉了,我也不知道该不该注销,先这样吧。接下来验证一下精度是否可以。

你可能感兴趣的:(libtorch,pytorch,深度学习,神经网络)