pytorch-crnn实践以及内置ctc_loss使用小结

最近开始深入OCR这块, 以前倒是训练过开源的Keras-CRNN, 但是它和原文还是不一样, 今天参照Keras-CRNN代码和CRNN论文用pytorch实现CRNN, 由于没有GPU, 自己造了100多张只包含数字的小图片来训练模型, 验证模型能否收敛

CRNN流程

在这儿不再详细谈CRNN论文了, 主要按照原文做一个流程描述:

  • 输入图片要求高度为32, 使用VGG提取特征,高度32倍下采样,因为要求最后高度维度为1,宽度可以根据情况来,宽度论文中是4倍下采样
  • CNN提取的特征为(batch_size, w/4, 1, 512), 挤压掉为1的维度后, 接上双向LSTM根据上下文特征进行预测
  • LSTM最终输出为(batch_size, w/4, 总字符类别), 即在每一个位置都会对属于所有字符的任意一个进行概率预测
  • 最后根据CTC_loss 进行计算损失

本文按照以上步骤展开

基本设置:

  • 任务背景: 数字识别 0-9, 加一个blank 共11个字符
  • 图片大小: 原CRNN中为 100 * 32 (宽 * 高), 本次实验环境下大小为 200 * 32
  • 识别字符的最长长度: 本次实验设置的最长长度为 20
  • 网络输出为 T 50(输入LSTM的数据的时间步, CNN 部分输出序列长度) * 11 (一共11个不同的字符, 有多少字符此处数字为多少)

网络构造

为了使用预训练的VGG权重, VGG backbone参照pytorch的VGG构造实现, 不然加载不了权重, 按照论文,第三层和第四层池化层核大小核步长改为(1, 2)

import torch.nn as nn
import torch.nn.functional as F


class VGG(nn.Module):

    def __init__(self):
        super(VGG, self).__init__()
        self.features = make_layers(cfgs['D'])

    def forward(self, x):
        x = self.features(x)
        return x


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    i = 0
    for v in cfg:
        if v == 'M':
            if i not in [9, 13]:
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))]

        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
        i += 1
    return nn.Sequential(*layers)


cfgs = {
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M']
}


class BidirectionalLSTM(nn.Module):

    def __init__(self, inp, nHidden, oup):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(inp, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, oup)

    def forward(self, x):
        out, _ = self.rnn(x)
        T, b, h = out.size()
        t_rec = out.view(T * b, h)

        output = self.embedding(t_rec)
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):
    def __init__(self, characters_classes, hidden=256, pretrain=True):
        super(CRNN, self).__init__()
        self.characters_class = characters_classes
        self.body = VGG()
        # 将VGG stage5-1 卷积单独拿出来, 改了卷积核无法加载预训练参数
        self.stage5 = nn.Conv2d(512, 512, kernel_size=(3, 2), padding=(1, 0))
        self.hidden = hidden
        self.rnn = nn.Sequential(BidirectionalLSTM(512, self.hidden, self.hidden),
                                 BidirectionalLSTM(self.hidden, self.hidden, self.characters_class))

        self.pretrain = pretrain
        if self.pretrain:
            import torchvision.models.vgg as vgg
            pre_net = vgg.vgg16(pretrained=True)
            pretrained_dict = pre_net.state_dict()
            model_dict = self.body.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.body.load_state_dict(model_dict)

            for param in self.body.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.body(x)
        x = self.stage5(x)
        # 挤压掉高所在的维度
        x = x.squeeze(3)
        # 转换为LSTM所需格式
        x = x.permute(2, 0, 1).contiguous()
        x = self.rnn(x)
        x = F.log_softmax(x, dim=2)
        return x

数据加载

数据集格式参照 MJSynth 数据集格式

import os
import cv2
import numpy as np
from torch.utils.data import Dataset


class RegDataSet(Dataset):
    def __init__(self, dataset_root, anno_txt_path, lexicon_path, target_size=(200, 32), characters="'-' + '0123456789'", transform=None):
        super(RegDataSet, self).__init__()
        self.dataset_root = dataset_root
        self.anno_txt_path = anno_txt_path
        self.lexicon_path = lexicon_path
        self.target_size = target_size
        self.height = self.target_size[1]
        self.width = self.target_size[0]
        self.characters = characters
        self.imgs = []
        self.lexicons = []
        self.parse_txt()
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        img_path, lexicon_index = self.imgs[item].split()
        lexicon = self.lexicons[int(lexicon_index)].strip()
        img = cv2.imread(os.path.join(self.dataset_root, img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_size = img.shape

        if (img_size[1] / (img_size[0] * 1.0)) < 6.4:
            img_reshape = cv2.resize(img, (int(32.0 / img_size[0] * img_size[1]), self.height))
            mat_ori = np.zeros((self.height, self.width - int(32.0 / img_size[0] * img_size[1]), 3), dtype=np.uint8)
            out_img = np.concatenate([img_reshape, mat_ori], axis=1).transpose([1, 0, 2])
        else:
            out_img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_CUBIC)
            out_img = np.asarray(out_img).transpose([1, 0, 2])

        label = [self.characters.find(c) for c in lexicon]
        if self.transform:
            out_img = self.transform(out_img)
        return out_img, label

    def parse_txt(self):
        self.imgs = open(os.path.join(self.dataset_root, self.anno_txt_path), 'r').readlines()
        self.lexicons = open(os.path.join(self.dataset_root, self.lexicon_path), 'r').readlines()

CTC LOSS

characters = "-0123456789"

ctc_loss = CTCLoss(blank=0, reduction='mean')
  • blank: 占位符 ‘-’ 所在的索引, 上例中为0
  • reduction: 处理loss的方式

损失计算:

ctc_loss(log_probs, targets, input_lengths, target_lengths)
  • log_probs: 网络输出的tensor, shape为 (T, N, C), T 为时间步, N 为batch_size, C 为字符总数(包括blank). 本例中,假如batch_size=8,网络输出为 (50,8,11)。网络输出需要进行log_softmax
  • targets: 目标tensor, targets有两种输入形式。其一: shape为 (N,S),N为batch_size,S 为识别序列的最长长度,值为每一个字符的index,不能包含blank的index。由于可能每个序列的长度不一样,而数组必须维度一样,就需要将短的序列padded 为最长序列长度(不过怎么padded没太弄明白,TensorFlow CTC_Loss 里面使用blank去填充, 但是这儿说了不能包含blank,有点迷糊, 还是用第二种吧)。 其二: 将该batch_size 内每一张图片的字符的index拼成一个一维数组. 会按照target_lengths 中的值自动对该一维数组中的index进行划分到对应图片
  • target_lengths: shape 为(N) 的Tensor, 每一个位置记录了对应图片所含有的字符数. 假如 N=4,即共有4张图片,每一张图片中包含的字符个数分别为: 8, 10, 12, 20, 那么 target_lengths = (8, 10, 12, 20), 同时targets 中共有 (8 + 10 + 12 + 20)个值,按照target_lengths中的值依次在targets 中取值即可
  • input_lengths: shape 为 (N) 的Tensor, 值为输出序列长度T, 因为图片宽度都固定了,所以都为T

个人实现代码如下:

def custom_collate_fn(batch, T=50):
        items = list(zip(*batch))
        items[0] = default_collate(items[0])
        labels = list(items[1])
        items[1] = []
        target_lengths = torch.zeros((len(batch,)), dtype=torch.int)
        input_lengths = torch.zeros(len(batch,), dtype=torch.int)
        for idx, label in enumerate(labels):
            # 记录每个图片对应的字符总数
            target_lengths[idx] = len(label)
            # 将batch内的label拼成一个list
            items[1].extend(label)
            # input_lengths 恒为 T
            input_lengths[idx] = T

        return items[0], torch.tensor(items[1]), target_lengths, input_lengths

batch_iterator = iter(DataLoader(trainSet, args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=custom_collate_fn))
images, labels, target_lengths, input_lengths = next(batch_iterator)
out = net(images)
loss = ctc_loss(log_probs=out, targets=labels, target_lengths=target_lengths, input_lengths=input_lengths)

解码

网络输出为 (50, 11), 解码先取每个位置的最大概率的字符index, index转str时,如果两个相同的index连续,那么合并为一个

例: 假设输出为: [1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 3, 0, 7, 7, 7, 0, 3, 3] , 由于后边全为0,只取前18位. 0 对应的字符是 ‘-’, 对于相邻的非0字符, 看做一个字符, 因此该例子为 [1,0,0,1,0,0,2,0,3,0,7,0,3], 再将0对应的blank 去掉, 则为实际的字符index为 [1,1,2,3,7,3]

def decode_out(str_index, characters):
    char_list = []
    for i in range(len(str_index)):
        if str_index[i] != 0 and (not (i > 0 and str_index[i - 1] == str_index[i])):
            char_list.append(characters[str_index[i]])
    return ''.join(char_list)

net_out = net(img)
_, preds = net_out.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
lab2str = decode_out(preds, args.characters)

完整代码

你可能感兴趣的:(计算机视觉)