CRNN: Environment Setup and Code Walkthrough

1. Environment Setup

CRNN was published in 2015, so some of its code is quite dated. Since then PyTorch has shipped its own CTCLoss, so as long as your PyTorch version is 1.0 or above you can use the built-in nn.CTCLoss; there is no need to build the dependency linked in the CRNN repository.

Moreover, the warp-ctc library has to be compiled with CMake and make. When I tried, the Makefile was never generated, or the DLL failed to build; the errors were all over the place. I therefore recommend simply using PyTorch 1.0+ and deleting this line from train.py:

from warpctc_pytorch import CTCLoss

My Python version is 3.8; the versions of the other packages are listed below:

certifi==2021.10.8
charset-normalizer==2.0.12
cycler==0.11.0
fonttools==4.30.0
idna==3.3
kiwisolver==1.3.2
lmdb==0.97
matplotlib==3.5.1
mkl-fft==1.3.0
mkl-random==1.1.1
mkl-service==2.3.0
numpy==1.17.3
opencv-python==4.5.5.64
packaging==21.3
Pillow==9.0.1
pyparsing==3.0.7
python-dateutil==2.8.2
requests==2.27.1
six==1.12.0
torch==1.11.0
torchvision==0.12.0
typing_extensions==4.1.1
urllib3==1.26.8
wincertstore==0.2

Also configure the trainRoot and valRoot paths. For building the training set, see https://github.com/bgshih/crnn#train-a-new-model. The author also provides a pretrained model here: https://pan.baidu.com/s/1pLbeCND

2. CRNN Walkthrough and Reimplementation

(1) The CRNN model is shown below. CRNN is a combination of a CNN and an RNN. The network begins with a stack of convolutional layers, using BatchNorm for normalization and LeakyReLU (or plain ReLU) as the non-linearity. Note that in this model the input height must be 32: only after all the convolution and pooling layers does the height shrink to 1, while the channel count grows from the input channels (1 here, since the images are converted to grayscale) to 512. In the transcription stage, squeeze removes the height dimension: an input of shape (batch_size, channels, height, width) becomes (batch_size, 512, 1, width) after the convolutions, (batch_size, 512, width) after the squeeze, and permute then rearranges it to (width, batch_size, 512). Why this order? Because an RNN expects its input shaped as (time steps, batch size, feature dimension). With a bidirectional LSTM the recurrent output has shape (T, b, 2*nHidden); it is flattened with view to (T*b, h), passed through a fully connected layer that produces a score for every class, and reshaped with view back to (T, b, -1). So the class predicted at each time step is computed from the 512 CNN output channels at that step.

class Blstm(nn.Module):
    def __init__(self, input ,nHidden, output):
        super(Blstm, self).__init__()
        self.rnn = nn.LSTM(input, nHidden, bidirectional = True)
        self.embedding = nn.Linear(nHidden*2, output)

    def forward(self, x):
        recurrent,_ = self.rnn(x)
        T,b,h = recurrent.size()
        t_rec = recurrent.view(T*b,h)
        return self.embedding(t_rec).view(T,b,-1)


class CRNN(nn.Module):
    # args: image height, input channels, number of classes, hidden units, whether to use LeakyReLU
    def __init__(self, img_Height, nc ,n_class, nh,leakyRule =False):
        super(CRNN , self).__init__()
        # the height must be a multiple of 16
        assert img_Height%16 == 0

        # layer configuration: kernel sizes, paddings, strides, output channels
        KS = [3,3,3,3,3,3,2]
        PS = [1,1,1,1,1,1,0]
        SS = [1,1,1,1,1,1,1]
        nm = [64, 128,256,256,512,512,512]

        cnn = nn.Sequential()

        def convRELU(i ,batchNormalization = False):
            nIn = nc if i == 0 else nm[i-1]
            out = nm[i]

            cnn.add_module('conv{0}'.format(i), nn.Conv2d(nIn, out, KS[i],  SS[i],  PS[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(out))
            if leakyRule:
                cnn.add_module("leakyRule{0}".format(i), nn.LeakyReLU(0.2, inplace= True))
            else:
                cnn.add_module("relu{0}".format(i), nn.ReLU(True))

        convRELU(0)
        cnn.add_module("pooling{0}".format(0), nn.MaxPool2d((2,2)))
        convRELU(1)
        cnn.add_module("pooling{0}".format(1), nn.MaxPool2d((2, 2)))
        convRELU(2,True)
        convRELU(3)
        cnn.add_module('pooling{0}'.format(2),nn.MaxPool2d((2,2), (2,1), (0,1)))
        convRELU(4,True)
        convRELU(5)
        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRELU(6, True)
        self.cnn =cnn
        self.rnn = nn.Sequential(Blstm(512, nh,nh), Blstm(nh,nh,n_class))

    def forward(self, input):
        conv = self.cnn(input)
        b,c,h,w = conv.size()
        assert h == 1
        conv = conv.squeeze(2)
        conv = conv.permute(2,0,1) #[w,b,c]
        return  self.rnn(conv)
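To make the shape bookkeeping above concrete, here is a small pure-Python sketch (no PyTorch needed; out_size is a hypothetical helper, not part of the model code) that traces a 32x100 input through the layers that change its size — the four pooling layers and the final kernel-2 convolution. The other convolutions are 'same'-padded and leave H and W untouched:

```python
def out_size(h, w, k, s, p):
    # standard conv/pool output-size formula: floor((x + 2p - k) / s) + 1
    kh, kw = k
    sh, sw = s
    ph, pw = p
    return (h + 2 * ph - kh) // sh + 1, (w + 2 * pw - kw) // sw + 1

h, w = 32, 100
h, w = out_size(h, w, (2, 2), (2, 2), (0, 0))  # pooling0 -> 16 x 50
h, w = out_size(h, w, (2, 2), (2, 2), (0, 0))  # pooling1 -> 8 x 25
h, w = out_size(h, w, (2, 2), (2, 1), (0, 1))  # pooling2 -> 4 x 26
h, w = out_size(h, w, (2, 2), (2, 1), (0, 1))  # pooling3 -> 2 x 27
h, w = out_size(h, w, (2, 2), (1, 1), (0, 0))  # conv6 (kernel 2, no padding) -> 1 x 26
print(h, w)  # 1 26
```

So a 32-pixel-high crop ends up as 26 time steps for the LSTM, which is why the assert h == 1 in forward holds.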

(2) The label-converter class, used for encoding and decoding; the code is shown below. At initialization the class decides whether to ignore case, controlled by the ignore flag. When building the encoding dictionary, take care not to occupy index 0, because PyTorch's CTCLoss uses index 0 as the blank token by default (this can be changed with the blank argument; the official documentation is here: https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html#torch.nn.CTCLoss).

encode handles two cases: a single string and a sequence of strings.

decode checks whether it is handling one sequence or several, and whether to apply the B-transform (the CTC collapsing map, usually written as the function B). The B-transform drops blanks and merges adjacent repeated characters; it is implemented by the condition t[i] != 0 and not (i > 0 and t[i-1] == t[i]). A single sequence can be B-transformed directly; for multiple sequences, decode recurses, slicing out each sub-sequence and decoding it as a single one.

class strLabelConverter():
    def __init__(self, alpha, ignore=True):
        self.ignore = ignore
        if self.ignore:
            alpha = alpha.lower()
        self.alpha = alpha +'-'
        self.dict = {}
        for i ,char in enumerate(alpha):
            # leave index 0 free: CTCLoss reserves 0 for the blank
            self.dict[char] = i+1
        print("self.dict", self.dict)

    def encode(self, text):
        if isinstance(text, str):
            text = [self.dict[item.lower() if self.ignore else item] for item in text]
            length = [len(text)]
        elif isinstance(text, collections.abc.Iterable):
            # a list/tuple of strings: record each length, then encode the concatenation
            length = [len(i) for i in text]
            text = ''.join(text)
            text = [self.dict[item.lower() if self.ignore else item] for item in text]
        return torch.IntTensor(text), torch.IntTensor(length)

    def decode(self, t, length, raw =False):
        if length.numel() == 1:
            length = length[0]
            assert  t.numel() == length
            if raw:
                return ''.join([self.alpha[i-1] for i in t])
            else:
                char_list =[]
                for i in range(length):
                    if t[i]!=0 and not((i>0 and t[i-1] == t[i])):
                        char_list.append((self.alpha[t[i] -1]))
                return ''.join(char_list)
        else:
            assert t.numel() == length.sum()
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(self.decode(t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
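The collapsing condition used in decode — the B-transform described above — can be illustrated with plain Python lists (ctc_collapse is a hypothetical stand-alone helper; index 0 plays the role of the CTC blank, and alphabet indices are shifted by one, exactly as in the dictionary built in __init__):

```python
def ctc_collapse(indices, alphabet):
    # keep a symbol only if it is not the blank (0) and not a repeat of the previous index
    out = []
    for i, idx in enumerate(indices):
        if idx != 0 and not (i > 0 and indices[i - 1] == idx):
            out.append(alphabet[idx - 1])
    return ''.join(out)

alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
# 'h'=18, 'e'=15, 'l'=22, 'o'=25 in the shifted dictionary; blanks and repeats collapse away
print(ctc_collapse([18, 18, 0, 15, 22, 22, 0, 22, 25], alphabet))  # hello
```

Note how the blank between the two 22s lets "ll" survive, while the adjacent repeated 18s merge into a single 'h'.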

(3) Image preprocessing, with the code shown below. CRNN constrains both the height and the width of its input, so each image is resized to height 32 and width 100. The reason: inside the network, the height is reduced to 1 by the convolutions and pooling and then removed in the transcription stage, while a width of 100 shrinks to 26 time steps that feed the two-layer bidirectional LSTM. The author's code converts each image to grayscale at the start, so the channel count expands from 1 to 512.

class resizeTransfrom():
    def __init__(self, size , interpolation =Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.Intensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.Intensor(img)
        return img.sub_(.5).div_(.5)
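The transform first maps pixel values into [0, 1] via ToTensor, and sub_(.5).div_(.5) then shifts them into [-1, 1]. The same arithmetic on a single pixel value, as a hypothetical pure-Python helper:

```python
def normalize(pixel):
    # ToTensor: 0..255 -> 0..1, then (x - 0.5) / 0.5 -> -1..1
    x = pixel / 255.0
    return (x - 0.5) / 0.5

print(normalize(0), normalize(255))  # -1.0 1.0
```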

(4) The dataset code. The author's original code uses an lmdb dataset, but I found that format too cumbersome, so I wrote my own Dataset, shown below (please go easy on it).

class lmdbDataset(Dataset):
    # image and ground-truth directories
    def __init__(self, trainImg, trainGt, transform=None, target_transform=None):
        # collect the file names under each directory; sort so images pair up with labels
        self.trainImgPath = trainImg
        self.trainGtPath = trainGt
        self.LabelList = sorted(os.listdir(trainGt))
        self.ImgList = sorted(os.listdir(trainImg))
        self.transform = transform
        self.target_transform = target_transform
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return  len(self.LabelList)

    def __getitem__(self, index):
        assert index < len(self)
        img = Image.open(os.path.join(self.trainImgPath, self.ImgList[index])).convert('L')
        label_path = os.path.join(self.trainGtPath, self.LabelList[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img = self.target_transform(img)
        with open(label_path, 'r') as f:
            label = f.read()

        return img, label
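One caveat with this Dataset: os.listdir returns entries in arbitrary order, so the image list and the label list should each be sorted to keep every image paired with its ground-truth file. A tiny illustration with hypothetical file names:

```python
# os.listdir gives no guaranteed order, so sorting keeps image/label pairs aligned
# (hypothetical layout: img_001.jpg <-> img_001.txt, etc.)
names = ['img_010.jpg', 'img_002.jpg', 'img_001.jpg']
labels = ['img_010.txt', 'img_001.txt', 'img_002.txt']
pairs = list(zip(sorted(names), sorted(labels)))
print(pairs[0])  # ('img_001.jpg', 'img_001.txt')
```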

(5) The DataLoader accepts a collate_fn argument, which controls how individual samples are assembled into each batch; we can define our own function to get exactly the behavior we want. Here it is also used to process the images so they fit the network: with the height fixed, the width is scaled proportionally. The code is as follows.

# Resize the batch images: fix the height at 32 and scale the width with the aspect ratio
class alignCollate(object):
    def __init__(self, imgH=32, imgW = 100, keep_ratio = False, min_ratio =1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        imges ,labels = zip(*batch)
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for img in imges:
                w, h = img.size  # PIL Image.size is an attribute, not a method
                ratios.append(w / float(h))
            max_ratio = max(ratios)
            imgW = max(int(np.floor(max_ratio * self.imgH)), self.imgH * self.min_ratio)
        transform = resizeTransfrom((imgW, self.imgH))
        imges = [transform(img) for img in imges]
        imges = torch.cat([t.unsqueeze(0) for t in imges],0)
        return imges, labels
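When keep_ratio is set, the batch width is derived from the widest aspect ratio in the batch (clamped below by min_ratio). A minimal stand-alone sketch of that computation — batch_width is a hypothetical helper name, not part of the class:

```python
import math

def batch_width(sizes, imgH=32, min_ratio=1):
    # sizes: list of (width, height) pairs; the largest w/h ratio in the batch wins
    max_ratio = max(w / float(h) for w, h in sizes)
    return max(int(math.floor(max_ratio * imgH)), imgH * min_ratio)

print(batch_width([(100, 32), (200, 32), (64, 32)]))  # 200
```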

(6) That covers the main classes; the complete code is shown below:


import collections
import os.path
import random
import torch
from PIL import Image
import torchvision.transforms as transforms
import argparse
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim


class Blstm(nn.Module):
    def __init__(self, input ,nHidden, output):
        super(Blstm, self).__init__()
        self.rnn = nn.LSTM(input, nHidden, bidirectional = True)
        self.embedding = nn.Linear(nHidden*2, output)

    def forward(self, x):
        recurrent,_ = self.rnn(x)
        T,b,h = recurrent.size()
        t_rec = recurrent.view(T*b,h)
        return self.embedding(t_rec).view(T,b,-1)


class CRNN(nn.Module):
    # args: image height, input channels, number of classes, hidden units, whether to use LeakyReLU
    def __init__(self, img_Height, nc ,n_class, nh,leakyRule =False):
        super(CRNN , self).__init__()
        # the height must be a multiple of 16
        assert img_Height%16 == 0

        # layer configuration: kernel sizes, paddings, strides, output channels
        KS = [3,3,3,3,3,3,2]
        PS = [1,1,1,1,1,1,0]
        SS = [1,1,1,1,1,1,1]
        nm = [64, 128,256,256,512,512,512]

        cnn = nn.Sequential()

        def convRELU(i ,batchNormalization = False):
            nIn = nc if i == 0 else nm[i-1]
            out = nm[i]

            cnn.add_module('conv{0}'.format(i), nn.Conv2d(nIn, out, KS[i],  SS[i],  PS[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(out))
            if leakyRule:
                cnn.add_module("leakyRule{0}".format(i), nn.LeakyReLU(0.2, inplace= True))
            else:
                cnn.add_module("relu{0}".format(i), nn.ReLU(True))

        convRELU(0)
        cnn.add_module("pooling{0}".format(0), nn.MaxPool2d((2,2)))
        convRELU(1)
        cnn.add_module("pooling{0}".format(1), nn.MaxPool2d((2, 2)))
        convRELU(2,True)
        convRELU(3)
        cnn.add_module('pooling{0}'.format(2),nn.MaxPool2d((2,2), (2,1), (0,1)))
        convRELU(4,True)
        convRELU(5)
        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRELU(6, True)
        self.cnn =cnn
        self.rnn = nn.Sequential(Blstm(512, nh,nh), Blstm(nh,nh,n_class))

    def forward(self, input):
        conv = self.cnn(input)
        b,c,h,w = conv.size()
        assert h == 1
        conv = conv.squeeze(2)
        conv = conv.permute(2,0,1) #[w,b,c]
        return  self.rnn(conv)


class strLabelConverter():
    def __init__(self, alpha, ignore=True):
        self.ignore = ignore
        if self.ignore:
            alpha = alpha.lower()
        self.alpha = alpha +'-'
        self.dict = {}
        for i ,char in enumerate(alpha):
            # leave index 0 free: CTCLoss reserves 0 for the blank
            self.dict[char] = i+1
        print("self.dict", self.dict)

    def encode(self, text):
        if isinstance(text, str):
            text = [self.dict[item.lower() if self.ignore else item] for item in text]
            length = [len(text)]
        elif isinstance(text, collections.abc.Iterable):
            # a list/tuple of strings: record each length, then encode the concatenation
            length = [len(i) for i in text]
            text = ''.join(text)
            text = [self.dict[item.lower() if self.ignore else item] for item in text]
        return torch.IntTensor(text), torch.IntTensor(length)

    def decode(self, t, length, raw =False):
        if length.numel() == 1:
            length = length[0]
            assert  t.numel() == length
            if raw:
                return ''.join([self.alpha[i-1] for i in t])
            else:
                char_list =[]
                for i in range(length):
                    if t[i]!=0 and not((i>0 and t[i-1] == t[i])):
                        char_list.append((self.alpha[t[i] -1]))
                return ''.join(char_list)
        else:
            assert t.numel() == length.sum()
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(self.decode(t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

class resizeTransfrom():
    def __init__(self, size , interpolation =Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.Intensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.Intensor(img)
        return img.sub_(.5).div_(.5)


from torch.utils.data import Dataset

class lmdbDataset(Dataset):
    # image and ground-truth directories
    def __init__(self, trainImg, trainGt, transform=None, target_transform=None):
        # collect the file names under each directory; sort so images pair up with labels
        self.trainImgPath = trainImg
        self.trainGtPath = trainGt
        self.LabelList = sorted(os.listdir(trainGt))
        self.ImgList = sorted(os.listdir(trainImg))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return  len(self.LabelList)

    def __getitem__(self, index):
        assert index < len(self)
        img = Image.open(os.path.join(self.trainImgPath, self.ImgList[index])).convert('L')
        label_path = os.path.join(self.trainGtPath, self.LabelList[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img = self.target_transform(img)
        with open(label_path, 'r') as f:
            label = f.read()

        return img, label

# Resize the batch images: fix the height at 32 and scale the width with the aspect ratio
class alignCollate(object):
    def __init__(self, imgH=32, imgW = 100, keep_ratio = False, min_ratio =1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        imges ,labels = zip(*batch)
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for img in imges:
                w, h = img.size  # PIL Image.size is an attribute, not a method
                ratios.append(w / float(h))
            max_ratio = max(ratios)
            imgW = max(int(np.floor(max_ratio * self.imgH)), self.imgH * self.min_ratio)
        transform = resizeTransfrom((imgW, self.imgH))
        imges = [transform(img) for img in imges]
        imges = torch.cat([t.unsqueeze(0) for t in imges],0)
        return imges, labels

# copy data into an existing tensor in place, resizing as needed, to avoid reallocating
def loadData(v, data):
    v.data.resize_(data.size()).copy_(data)


if __name__ == '__main__':

    # -----------------------------------------------------------------------------
    # Model training code
    # -----------------------------------------------------------------------------
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
    parser = argparse.ArgumentParser()
    parser.add_argument('--trainRoot',default=r"E:\example\yanmacode\CRNN\crnn.pytorch\dataset",  help='path to dataset')
    parser.add_argument('--valRoot', default=r'E:\example\yanmacode\CRNN\crnn.pytorch\dataset',  help='path to dataset')
    parser.add_argument('--expr_dir', default='expr', help='Where to store samples and models')
    parser.add_argument('--manualSeed', type=int, default=1234, help='reproduce experiment')
    parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
    parser.add_argument('--epoch', type=int, default=25, help='number of epochs to train for')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
    parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')

    opt = parser.parse_args()
    if not os.path.exists(opt.expr_dir):
        os.makedirs(opt.expr_dir)
    # seed the random number generators
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    # benchmark mode speeds up computation, but adds slight run-to-run nondeterminism; set it to False if you need exactly reproducible results
    cudnn.benchmark = True
    trainImg = r'E:\example\yanmacode\CRNN\crnn.pytorch\dataset\train\JPEGImages'
    trainGt = r'E:\example\yanmacode\CRNN\crnn.pytorch\dataset\train\Annotations'
    dataset = lmdbDataset(trainImg, trainGt)
    assert  dataset
    fn = alignCollate()
    batch_size = 1
    trainLoader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, collate_fn=fn)
    nclass = len(alphabet)+1
    nc=1
    # label encoder/decoder
    converter = strLabelConverter(alphabet)
    # build the model
    crnn = CRNN(32,1,nclass, 256)
    # loss function
    Loss = nn.CTCLoss()
    print(opt)
    # choose the optimizer; RMSprop is the default
    if opt.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=0.01,
                               betas=(opt.beta1, 0.999))
    elif opt.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=0.01)
    # weight initialization
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    # move the model to the GPU if available (an optimizer object has no .cuda() method)
    if torch.cuda.is_available():
        crnn = crnn.cuda()
    crnn.apply(weights_init)
    # put the model in training mode
    crnn.train()
    train_cost = 0
    for epoch in range(opt.epoch):
        print('----------------------')
        # iterate over the training set
        train_cost= 0
        for img, label in trainLoader:
            if torch.cuda.is_available():
                img = img.cuda()
            # forward pass; nn.CTCLoss expects log-probabilities, so apply log_softmax
            preds = crnn(img).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            # labels are plain strings and stay on the CPU
            target, length = converter.encode(label)
            cost = Loss(preds, target, preds_size, length)
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()
            train_cost += cost.item()
        print('Total training loss over the epoch:', train_cost)
        
        
# ---------------- Comment out the section below when running the one above, and vice versa ----------------


    
    # -----------------------------------------------------------------------------
    # Model inference code
    # -----------------------------------------------------------------------------

    model_path = './data/crnn.pth'
    img_path = './data/demo.png'
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
    model = CRNN(32,1,37, 256)
    if torch.cuda.is_available():
        model = model.cuda()
    model.load_state_dict(torch.load(model_path))
    converter = strLabelConverter(alphabet)
    transform = resizeTransfrom((100,32))
    img = Image.open(img_path).convert('L')
    img = transform(img)
    if torch.cuda.is_available():
        img = img.cuda()
    img = img.view(1, *img.size())
    model.eval()
    preds = model(img)
    _, preds = preds.max(2)
    preds = preds.transpose(1,0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])
    raw_pred = converter.decode(preds.data, preds_size.data, raw =True)
    pred = converter.decode(preds.data, preds_size.data,raw = False)
    print(raw_pred,"-->", pred)
