cuda11.1_torch1.8_python3_crnn训练小记

1、制作lmdb数据集

由于数据集制作不涉及核心,我就直接引用了python2的代码,当然有时间的同学们也可以自己改成python3的。
先介绍环境:我使用 Anaconda 新建了一个 Python 2.7 的环境(注意:opencv-python 对 Python 2 的支持止于 4.2.0.32 版本,且只提供 Python 2.7 的安装包!)
这是我的pip list:

Package       Version            
------------- -------------------
certifi       2020.6.20          
lmdb          1.3.0              
numpy         1.16.6             
opencv-python 4.2.0.32           
pip           19.3.1             
setuptools    44.0.0.post20200106
wheel         0.37.1    

接着上数据集:
[图:数据集示例]
关于数据集的标注,大家可以使用 labelImg,把每个框的标签写成框内对应的文字内容即可。标注完成后,根据标注的坐标依次把子图抠出来,再为每张子图生成对应的 txt 标签文件就可以了(见下面的脚本)。

from __future__ import division

import os

from PIL import Image

import xml.dom.minidom

import numpy as np

import io  # io.open writes utf-8 text files on both Python 2 and Python 3

# Input folders: source images and their labelImg (Pascal VOC) XML annotations.
ImgPath = './images/'

AnnoPath = './annotations/'

# Output folders: one cropped image and one label .txt per annotated box.
ProcessedPath_jpg = './crops/'
ProcessedPath_txt = './labels/'

# Create the output folders once, up front, instead of re-checking per object.
for out_dir in (ProcessedPath_jpg, ProcessedPath_txt):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)


def _read_int(box, tag):
    """Return the integer value of the first <tag> child of a <bndbox> node."""
    # int(float(...)) also accepts coordinates written like '123.0'.
    return int(float(box.getElementsByTagName(tag)[0].childNodes[0].data))


imagelist = os.listdir(ImgPath)

for image in imagelist:
    image_pre, ext = os.path.splitext(image)
    imgfile = os.path.join(ImgPath, image)
    xmlfile = os.path.join(AnnoPath, image_pre + '.xml')

    print("imgfile------->", imgfile)
    print("image_pre------->", image_pre)

    # Only images that have an XML annotation with the same base name are cropped.
    if not os.path.exists(xmlfile):
        continue

    annotation = xml.dom.minidom.parse(xmlfile).documentElement
    objectlist = annotation.getElementsByTagName('object')

    # Open the source image once per file (not once per box) and always close it.
    img = Image.open(imgfile)
    try:
        i = 1  # crop counter; restarts at 1 for every source image
        for objects in objectlist:
            # The <name> of each object is used as the text label of its crop(s).
            objectname = objects.getElementsByTagName('name')[0].childNodes[0].data
            print("objectname------->", objectname)

            for box in objects.getElementsByTagName('bndbox'):
                x1 = _read_int(box, 'xmin')
                y1 = _read_int(box, 'ymin')
                x2 = _read_int(box, 'xmax')
                y2 = _read_int(box, 'ymax')

                # Output names: "<x1>_<y1>_<x2>_<y2>_<label>_<image>_<i>.jpg/.txt"
                bbox_name_cor = "%d_%d_%d_%d_" % (x1, y1, x2, y2)
                base_name = bbox_name_cor + objectname + "_" + image_pre + '_' + str(i)
                save_item_jpg = os.path.join(ProcessedPath_jpg, base_name + '.jpg')
                save_item_txt = os.path.join(ProcessedPath_txt, base_name + '.txt')
                print("save_item_jpg----->", save_item_jpg)

                cropedimg = img.crop((x1, y1, x2, y2))
                # JPEG cannot store alpha/palette modes (e.g. RGBA PNGs); convert first.
                if cropedimg.mode not in ('RGB', 'L'):
                    cropedimg = cropedimg.convert('RGB')
                cropedimg.save(save_item_jpg)

                # The with-block closes (and flushes) each label file immediately.
                with io.open(save_item_txt, 'w', encoding='utf-8') as out_file:
                    out_file.write(objectname)
                i += 1
    finally:
        img.close()

拿到数据集后,把 ./crops/ 中的图片和 ./labels/ 中的 txt 放到同一个文件夹内(下面代码中的 path 默认是 ./data_sample/*.png;如果你的图片是 .jpg,请相应修改 path),然后使用 Python 2.7 运行以下代码:

# -*- coding: utf-8 -*-
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
#from genLineText import GenTextImage
 
def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to a non-empty image.

    Args:
        imageBin: raw encoded image bytes (the contents of an image file),
            or None.

    Returns:
        False when the input is None/empty, cannot be decoded by OpenCV, or
        decodes to an image with zero height or width; True otherwise.
    """
    # Empty input is rejected here because cv2.imdecode raises on an empty buffer.
    if not imageBin:
        return False
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0
 
 
def writeCache(env, cache):
    """Write every key/value pair of ``cache`` into ``env`` in one write transaction.

    Args:
        env: an open ``lmdb`` environment.
        cache: dict mapping keys to values; each may be bytes or text.
    """
    with env.begin(write=True) as txn:
        # .items() works on Python 2 and 3 (.iteritems() is Python 2 only).
        for k, v in cache.items():
            # LMDB stores raw bytes. Text (Python 3 str / Python 2 unicode) is
            # encoded as utf-8. Bytes (including Python 2 str) pass through unchanged.
            if not isinstance(k, bytes):
                k = k.encode('utf-8')
            if not isinstance(v, bytes):
                v = v.encode('utf-8')
            txn.put(k, v)

 
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image

    Records are written under 1-based keys 'image-%09d' / 'label-%09d'
    (and 'lexicon-%09d' when a lexicon is given), plus a final
    'num-samples' entry holding the number of samples actually written.
    Missing or undecodable images are skipped and do not consume a key.
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    print('...................')
    # map_size=1099511627776 sets the maximum database size to 1 TB.
    env = lmdb.open(outputPath, map_size=1099511627776)

    cache = {}
    cnt = 1  # next 1-based key to write
    for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        # Binary mode: text mode corrupts image bytes on Windows (Python 2)
        # and fails to decode them under Python 3.
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid and not checkImageIsValid(imageBin):
            print('%s is not a valid image' % imagePath)
            continue

        # The .mdb file stores two kinds of records, image data and label
        # data, each under its own key.
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        # Flush every 1000 samples to bound the size of the in-memory cache.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    env.close()
    print('Created dataset with %d samples' % nSamples)


def read_text(path):
    """Return the contents of the text file at ``path`` with surrounding whitespace stripped."""
    with open(path) as f:
        text = f.read()
    return text.strip()


import glob
if __name__ == '__main__':

    # LMDB output directory
    outputPath = './lmdb3_sample'

    # Training images. Each label is a .txt file with the same base name as
    # its image, e.g. 123.png -> 123.txt (use '*.jpg' here for jpg images).
    path = './data_sample/*.png'

    imagePathList = glob.glob(path)
    print('------------ %d ------------' % len(imagePathList))
    imgLabelLists = []
    for p in imagePathList:
        # Replace only the extension, so '.png' inside a folder name is left alone.
        labelPath = os.path.splitext(p)[0] + '.txt'
        try:
            imgLabelLists.append((p, read_text(labelPath)))
        except (IOError, OSError, UnicodeDecodeError):
            # Images without a readable label file are skipped, not fatal.
            print('%s has no readable label file, skipped' % p)
            continue

    # Sort samples by label length.
    imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1]))
    imgPaths = [p[0] for p in imgLabelList]
    txtLists = [p[1] for p in imgLabelList]

    createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)

运行完后,生成如下文件:
[图:生成的 LMDB 数据库文件]

接下来就可以训练了

2、训练

可以从 GitHub 上的 meijieru/crnn.pytorch 仓库获取一份代码。拿到后记得修改,否则在我们使用的高版本 CUDA、torch 下会报错。主要改动是去掉了 Variable,并修复了其他一些 bug,其中一部分参考了这位作者的修改:

https://www.cnblogs.com/yanghailin/p/14519525.html

问题不大

1、train.py

from __future__ import print_function
from __future__ import division

import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
import os
import utils
import dataset
import models.crnn as crnn

def _str2bool(value):
    """Parse a command-line boolean such as 'True'/'False', 'yes'/'no' or '1'/'0'.

    Without an explicit type, ``--cuda False`` stores the non-empty string
    'False', which is truthy, so the option could never be switched off.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--trainRoot', default="./data/lmdb/", help='path to dataset')
parser.add_argument('--valRoot', default="./data/lmdb/", help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=5, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for')
# TODO(meijieru): epoch -> iter
parser.add_argument('--cuda', type=_str2bool, default=True, help='enables cuda (pass --cuda False to run on CPU)')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--pretrained', default='', help="path to pretrained model (to continue training)")
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
parser.add_argument('--expr_dir', default='expr', help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=1, help='Interval (in iterations) between loss printouts')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=1, help='Interval (in iterations) between validation runs')
parser.add_argument('--saveInterval', type=int, default=100, help='Interval (in iterations) between checkpoints')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for adam/rmsprop, not used by adadelta')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (takes priority over --adadelta)')
parser.add_argument('--adadelta', type=_str2bool, default=True, help='Whether to use adadelta (default True; pass False to use rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--manualSeed', type=int, default=1234, help='reproduce experiemnt')
parser.add_argument('--random_sample', default=False, action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)

os.makedirs(opt.expr_dir, exist_ok=True)

# Seed every RNG used by data sampling and weight initialisation.
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Let cuDNN autotune convolution algorithms (faster, not bit-exact reproducible).
cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
if opt.cuda and not torch.cuda.is_available():
    # --cuda defaults to True. Fall back to CPU instead of crashing on .cuda() later.
    print("WARNING: --cuda was requested but no CUDA device is available; running on CPU")
    opt.cuda = False

train_dataset = dataset.lmdbDataset(root=opt.trainRoot)
assert train_dataset
# Both training loaders resize each batch to (imgH, imgW) with alignCollate.
train_collate = dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio)
if opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize, sampler=sampler,
        num_workers=int(opt.workers),
        collate_fn=train_collate)
else:
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize,
        shuffle=True, sampler=None,
        num_workers=int(opt.workers),
        collate_fn=train_collate)

# Validation images are resized to the same (width, height) as training images.
Val_dataset = dataset.lmdbDataset(
    root=opt.valRoot, transform=dataset.resizeNormalize((opt.imgW, opt.imgH)))
Valdata_loader = torch.utils.data.DataLoader(
    Val_dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))


# Output classes: every alphabet character plus the CTC blank, which
# strLabelConverter reserves as label 0.
nclass = 1 + len(opt.alphabet)
# One input channel: lmdbDataset converts every image to grayscale ('L').
nc = 1

converter = utils.strLabelConverter(opt.alphabet)

# blank=0 is CTCLoss's default; spelled out to match the converter's blank label.
criterion = torch.nn.CTCLoss(blank=0)
# custom weights initialization called on crnn
def weights_init(m):
    """Re-initialise convolution and batch-norm layers in place.

    Intended for ``nn.Module.apply``. Modules whose class name contains
    'Conv' get weights drawn from N(0, 0.02). Modules whose class name
    contains 'BatchNorm' get weights from N(1, 0.02) and a zero bias. Every
    other module is left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.pretrained != '':
    import collections
    print('loading pretrained model from %s' % opt.pretrained)
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine. The
    # weights follow the model to the GPU below.
    load_model_ = torch.load(opt.pretrained, map_location='cpu')
    state_dict_rename = collections.OrderedDict()
    for k, v in load_model_.items():
        # Checkpoints saved through DataParallel prefix every key with
        # 'module.'; strip it only when it is actually there.
        name = k[len('module.'):] if k.startswith('module.') else k
        state_dict_rename[name] = v
    # state_dict_rename only exists when a checkpoint was given, so loading
    # belongs inside this branch.
    crnn.load_state_dict(state_dict_rename)
print(crnn)

# Reusable input buffers; utils.loadData resizes them to each batch's shape.
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgW)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    crnn.cuda()
    crnn = torch.nn.DataParallel(crnn, device_ids=list(range(opt.ngpu)))
    image = image.cuda()
    criterion = criterion.cuda()


# Running average of the training loss, printed and reset every displayInterval.
loss_avg = utils.averager()

# Optimizer choice: --adam wins, then adadelta unless it is switched off,
# and RMSprop is the remaining fallback.
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif not opt.adadelta:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
else:
    optimizer = optim.Adadelta(crnn.parameters())


def val(net, Valdata_loader, criterion, max_iter=100):
    """Evaluate ``net`` on at most ``max_iter`` validation batches.

    Prints a few raw/decoded predictions from the last batch, then the
    average loss and the exact-match accuracy (targets are lower-cased before
    comparison). Uses the module-level ``image``/``text``/``length`` buffers,
    ``converter`` and ``opt``.

    NOTE(review): torch.nn.CTCLoss expects log-probabilities laid out as
    (T, N, C). This assumes models/crnn.py returns log_softmax outputs in
    that layout — TODO confirm.
    """
    print('Start val')
    net.eval()
    val_iter = iter(Valdata_loader)
    n_correct = 0
    n_seen = 0  # samples actually evaluated (the last batch may be short)
    loss_avg = utils.averager()
    max_iter = min(max_iter, len(Valdata_loader))
    preds = preds_size = None
    sim_preds, cpu_texts = [], []
    with torch.no_grad():
        for _ in range(max_iter):
            # next() works on every DataLoader iterator. The .next() alias is not
            # available on all torch versions.
            cpu_images, cpu_texts = next(val_iter)
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)

            preds = net(image)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)

            # Greedy decoding: best class per time step, transposed to
            # batch-major so each sample's predictions are contiguous.
            # (No squeeze: it would drop the batch axis when batch_size == 1.)
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            for pred, target in zip(sim_preds, cpu_texts):
                if pred == target.lower():
                    n_correct += 1
            n_seen += batch_size

    if preds is not None:
        raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    accuracy = n_correct / float(n_seen) if n_seen else 0.0
    print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))


def trainBatch(net, criterion, optimizer):
    """Run one optimisation step of ``net`` on the next batch of ``train_iter``.

    Uses the module-level ``train_iter``, the ``image``/``text``/``length``
    buffers and ``converter``.

    Returns:
        The detached, batch-size-normalised CTC loss, suitable for logging.
    """
    cpu_images, cpu_texts = next(train_iter)
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = net(image)
    # Every sample in the batch uses the full output sequence length T = preds.size(0).
    preds_size = torch.IntTensor([preds.size(0)] * batch_size)

    cost = criterion(preds, text, preds_size, length) / batch_size
    net.zero_grad()
    cost.backward()
    optimizer.step()
    # Detach so that callers accumulating the loss do not keep this step's graph alive.
    return cost.detach()


for epoch in range(opt.nepoch):
    # trainBatch reads batches from this module-level iterator.
    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):
        # val() switches the model to eval mode, so restore training mode every step.
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        if i % opt.displayInterval == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, opt.nepoch, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()

        if i % opt.valInterval == 0:
            val(crnn, Valdata_loader, criterion)

        # do checkpointing
        if i % opt.saveInterval == 0:
            torch.save(
                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.expr_dir, epoch, i))

        # Additionally save after the first iteration of every 100th epoch.
        if (0 != epoch) and (epoch % 100 == 0) and (1 == i):
            torch.save(
                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.expr_dir, epoch, i))

2、dataset.py

#!/usr/bin/python
# encoding: utf-8

import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
import lmdb
import six
import sys
from PIL import Image
import numpy as np




class lmdbDataset(Dataset):
    """(image, label) dataset backed by an LMDB database.

    Expects a 'num-samples' record plus 1-based 'image-%09d' (encoded image
    bytes) and 'label-%09d' (label bytes, decoded with ``bytes.decode()``)
    records. Images are returned as grayscale PIL images, optionally passed
    through ``transform``. Labels are optionally passed through
    ``target_transform``.
    """

    def __init__(self, root=None, transform=None, target_transform=None):
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            sys.exit(0)

        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode()))
            print("nSamples===================", nSamples)
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # Valid indices are 0 .. len(self) - 1. index == len(self) would map
        # to a key that does not exist.
        assert 0 <= index < len(self), 'index range error'
        index += 1  # LMDB keys are 1-based
        # Fetch both records in one short read transaction. Decoding happens
        # outside it, so the retry below does not nest transactions.
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(('image-%09d' % index).encode())
            label_byte = txn.get(('label-%09d' % index).encode())

        try:
            img = Image.open(six.BytesIO(imgbuf)).convert('L')
        except IOError:
            print('Corrupted image for %d' % index)
            # `index` is now the 1-based key of the bad record, which equals
            # the 0-based index of the next sample. Wrap to 0 after the last one.
            return self[index % len(self)]

        if self.transform is not None:
            img = self.transform(img)

        label = label_byte.decode()
        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label)


class resizeNormalize(object):
    """Resize a PIL image and convert it to a tensor normalised to [-1, 1].

    Args:
        size: (width, height) target size, in the order ``PIL.Image.resize`` expects.
        interpolation: PIL resampling filter used for the resize.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        # ToTensor yields a (C, H, W) float tensor scaled to [0, 1] for 8-bit images.
        img = self.toTensor(img)
        # In-place (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
        img.sub_(0.5).div_(0.5)
        return img


class randomSequentialSampler(sampler.Sampler):
    """Sampler that yields runs of consecutive indices from random start points.

    Each full batch is ``batch_size`` contiguous indices starting at a random
    offset. A final partial batch of ``len % batch_size`` indices is also
    contiguous. Exactly ``len(data_source)`` indices are yielded per epoch,
    and indices may repeat across batches.
    """

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.zeros(len(self), dtype=torch.long)
        # Largest start that keeps a full batch in range. Clamped at 0 so a
        # dataset smaller than one batch still works.
        max_start = max(len(self) - self.batch_size, 0)
        for i in range(n_batch):
            random_start = random.randint(0, max_start)
            batch_index = random_start + torch.arange(0, self.batch_size)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            random_start = random.randint(0, max_start)
            tail_index = random_start + torch.arange(0, tail)
            # Start after the full batches. This avoids relying on the loop
            # variable, which is unset when n_batch == 0.
            index[n_batch * self.batch_size:] = tail_index

        # Yield plain Python ints rather than 0-dim tensors.
        return iter(index.tolist())

    def __len__(self):
        return self.num_samples


class alignCollate(object):
    """DataLoader collate function that resizes a batch of PIL images to one size.

    Args:
        imgH: output height.
        imgW: output width, used when ``keep_ratio`` is False.
        keep_ratio: if True, the width is derived from the widest aspect
            ratio in the batch instead of ``imgW``.
        min_ratio: lower bound on width / height when ``keep_ratio`` is True.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            # PIL's .size is (width, height). Size the batch for its widest image.
            max_ratio = max(w / float(h) for w, h in (image.size for image in images))
            imgW = int(np.floor(max_ratio * imgH))
            # Ensure imgW >= imgH * min_ratio. int() keeps the size integral
            # when min_ratio is a float.
            imgW = int(max(imgH * self.min_ratio, imgW))

        transform = resizeNormalize((imgW, imgH))
        images = [transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)

        return images, labels

3、utils.py

#!/usr/bin/python
# encoding: utf-8
import time

import torch
import torch.nn as nn
import collections


class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=True):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        # The trailing '-' lets label 0 (blank) render as alphabet[0 - 1] in raw decoding.
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def _encode_chars(self, chars):
        """Map every character of ``chars`` to its 1-based label."""
        if self._ignore_case:
            return [self.dict[char.lower()] for char in chars]
        return [self.dict[char] for char in chars]

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or iterable of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            TypeError: if ``text`` is neither a str nor an iterable of str.
            KeyError: if a character is not in the alphabet.
        """
        # collections.Iterable was removed in Python 3.10; collections.abc is the home.
        from collections.abc import Iterable

        if isinstance(text, str):
            codes = self._encode_chars(text)
            length = [len(codes)]
        elif isinstance(text, Iterable):
            # Materialise once so a generator is not exhausted by the length pass.
            texts = list(text)
            length = [len(s) for s in texts]
            codes = self._encode_chars(''.join(texts))
        else:
            raise TypeError('text must be a str or an iterable of str, got %s'
                            % type(text).__name__)
        return (torch.IntTensor(codes), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
            raw (bool): if True, map every label to a character (blank shown
                as '-'). Otherwise apply CTC collapsing: drop blanks and
                merge repeated labels.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): a str when ``length`` holds one entry,
            otherwise a list with one str per entry.
        """
        if length.numel() == 1:
            length = int(length.reshape(-1)[0])
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            codes = t.reshape(-1).tolist()
            if raw:
                return ''.join([self.alphabet[c - 1] for c in codes])
            char_list = []
            for i, c in enumerate(codes):
                # Skip blanks (0) and repeats of the previous time step.
                if c != 0 and not (i > 0 and codes[i - 1] == c):
                    char_list.append(self.alphabet[c - 1])
            return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for l in length.tolist():
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts


class averager(object):
    """Running mean over all elements of the tensors passed to ``add``."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate every element of tensor ``v`` into the running mean."""
        count = v.numel()
        # detach() keeps the running sum from holding on to the autograd
        # graph of every loss added before the next reset().
        v = v.detach().sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean so far, or 0 if nothing has been added."""
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res


def oneHot(v, v_length, nc):
    """One-hot encode a batch of concatenated label sequences.

    Args:
        v: 1-D tensor of labels for all sequences, concatenated.
        v_length: 1-D tensor with the length of each sequence.
        nc: number of classes.

    Returns:
        Float tensor of shape (batch, max_length, nc). Positions past a
        sequence's length stay all-zero.
    """
    batchSize = v_length.size(0)
    # Plain ints are required as sizes/slice bounds. Empty batches have no max().
    maxLength = int(v_length.max()) if batchSize > 0 else 0
    v_onehot = torch.zeros(batchSize, maxLength, nc)
    acc = 0
    for i in range(batchSize):
        length = int(v_length[i])
        label = v[acc:acc + length].view(-1, 1).long()
        v_onehot[i, :length].scatter_(1, label, 1.0)
        acc += length
    return v_onehot


def loadData(v, data):
    """Resize ``v`` in place to the shape of ``data`` and copy ``data`` into it."""
    target_shape = data.shape
    v.resize_(target_shape)
    v.copy_(data)


def prettyPrint(v):
    """Print the size and type of tensor ``v`` followed by its max, min and mean."""
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    # .item() replaces the pre-0.4 `.data[0]`, which fails on 0-dim tensors.
    # float() lets mean() work on integer tensors too.
    print('| Max: %f | Min: %f | Mean: %f' % (v.max().item(), v.min().item(),
                                              v.float().mean().item()))


def assureRatio(img):
    """Ensure imgH <= imgW for a 4-D (B, C, H, W) batch.

    A batch taller than it is wide is bilinearly upsampled (align_corners=True)
    to a square H x H. Otherwise the input object is returned unchanged.
    """
    b, c, h, w = img.size()
    if h > w:
        # Same computation nn.UpsamplingBilinear2d performs, without the
        # deprecated module wrapper.
        img = nn.functional.interpolate(img, size=(h, h), mode='bilinear', align_corners=True)
    return img

3、测试

新建一个 test1 文件夹,放入要测试的图片(可以是训练集中的图片,也可以是没有参与训练的图片),测试代码如下:

import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
import matplotlib.pyplot as plt
import collections
import os

import models.crnn as crnn


model_path = 'netCRNN_1400_1.pth'

alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

dir_img = "./test1/"

# One class per alphabet character plus the CTC blank (label 0).
nclass = len(alphabet) + 1

model = crnn.CRNN(32, 1, nclass, 256)
if torch.cuda.is_available():
    model = model.cuda()

# map_location='cpu' so a checkpoint saved on GPU also loads on a CPU-only machine.
load_model_ = torch.load(model_path, map_location='cpu')

state_dict_rename = collections.OrderedDict()
for k, v in load_model_.items():
    # Keys saved through DataParallel start with 'module.'; strip it only if present.
    name = k[len('module.'):] if k.startswith('module.') else k
    state_dict_rename[name] = v

print('loading pretrained model from %s' % model_path)
model.load_state_dict(state_dict_rename)
model.eval()

converter = utils.strLabelConverter(alphabet)

transformer = dataset.resizeNormalize((100, 32))

list_img = os.listdir(dir_img)
for cnt, img_name in enumerate(list_img):
    print(cnt, img_name)
    path_img = os.path.join(dir_img, img_name)
    if not os.path.isfile(path_img):
        continue  # skip sub-folders and other non-files

    image = Image.open(path_img).convert('L')
    image = transformer(image)
    if torch.cuda.is_available():
        image = image.cuda()
    # Add a leading batch dimension of 1.
    image = image.view(1, *image.size())

    with torch.no_grad():
        preds = model(image)

    # Greedy decoding: best class per time step, flattened to one sequence.
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = torch.IntTensor([preds.size(0)])
    raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    print('%-20s => %-20s' % (raw_pred, sim_pred))

    image_show = Image.open(path_img)
    plt.figure("show")
    plt.imshow(image_show)
    plt.show()

以上就是全部内容。接下来可以使用 tensorrtx 进行加速,大家可以自行参考 tensorrtx 项目。
大家可以加群一起探讨
cuda11.1_torch1.8_python3_crnn训练小记_第2张图片

你可能感兴趣的:(crnn,机器学习,torch,python,人工智能,深度学习)