最近开始深入OCR这块, 以前倒是训练过开源的Keras-CRNN, 但是它和原文还是不一样, 今天参照Keras-CRNN代码和CRNN论文用pytorch实现CRNN, 由于没有GPU, 自己造了100多张只包含数字的小图片来训练模型, 验证模型能否收敛
在这儿不再详细谈CRNN论文了, 主要按照原文做一个流程描述:
本文按照以上步骤展开
基本设置:
为了使用预训练的VGG权重, VGG backbone参照pytorch的VGG构造实现, 不然加载不了权重, 按照论文, 第三层和第四层池化层核大小和步长改为(1, 2)
import torch.nn as nn
import torch.nn.functional as F
class VGG(nn.Module):
    """VGG-16 convolutional backbone built from the cfgs['D'] layout."""

    def __init__(self):
        super(VGG, self).__init__()
        # Feature stack; the modified pooling geometry lives in make_layers.
        self.features = make_layers(cfgs['D'])

    def forward(self, x):
        return self.features(x)
def make_layers(cfg, batch_norm=False):
    """Build the VGG convolutional stack described by cfg.

    Args:
        cfg: list of ints (conv output channels) and 'M' markers (max-pool).
        batch_norm: insert BatchNorm2d after each convolution when True.

    Returns:
        nn.Sequential of Conv2d/(BatchNorm2d)/ReLU and MaxPool2d layers.

    Following the CRNN paper, the 3rd and 4th pooling layers (cfg indices
    9 and 13 in cfgs['D']) use a (1, 2) kernel and stride so one spatial
    axis is preserved while the other is halved.
    """
    layers = []
    in_channels = 3
    for i, v in enumerate(cfg):  # enumerate replaces the manual counter
        if v == 'M':
            # Rectangular pooling for the 3rd/4th pool, square otherwise.
            if i in (9, 13):
                layers.append(nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)))
            else:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(v))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v
    return nn.Sequential(*layers)
# VGG-16 'D' configuration: ints are conv output channels, 'M' marks a
# max-pool. The standard VGG-16 fifth stage is not listed here; its first
# conv is handled separately (see CRNN.stage5) because its kernel differs
# from the pretrained weights.
cfgs = {
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M']
}
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection."""

    def __init__(self, inp, nHidden, oup):
        super(BidirectionalLSTM, self).__init__()
        # Two directions, so the linear layer sees 2 * nHidden features.
        self.rnn = nn.LSTM(inp, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, oup)

    def forward(self, x):
        recurrent, _ = self.rnn(x)
        seq_len, batch, feat = recurrent.size()
        # Fold time and batch together so one Linear maps every step at once.
        projected = self.embedding(recurrent.view(seq_len * batch, feat))
        # Restore the (T, B, oup) layout expected by the next layer.
        return projected.view(seq_len, batch, -1)
class CRNN(nn.Module):
    """CRNN: VGG feature extractor -> 2-layer BiLSTM -> per-step log-probs.

    characters_classes: number of output classes (characters + CTC blank).
    hidden: LSTM hidden size.
    pretrain: load and freeze ImageNet VGG-16 weights for the backbone.
    """

    def __init__(self, characters_classes, hidden=256, pretrain=True):
        super(CRNN, self).__init__()
        self.characters_class = characters_classes
        self.body = VGG()
        # Stage-5 conv is kept outside VGG: its (3, 2) kernel differs from
        # torchvision's weights, so it cannot load pretrained parameters.
        self.stage5 = nn.Conv2d(512, 512, kernel_size=(3, 2), padding=(1, 0))
        self.hidden = hidden
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, self.hidden, self.hidden),
            BidirectionalLSTM(self.hidden, self.hidden, self.characters_class))
        self.pretrain = pretrain
        if self.pretrain:
            import torchvision.models.vgg as vgg
            # Copy every matching key from the ImageNet VGG-16 state dict
            # into the backbone, then freeze the backbone.
            pre_net = vgg.vgg16(pretrained=True)
            backbone_state = self.body.state_dict()
            matched = {k: v for k, v in pre_net.state_dict().items()
                       if k in backbone_state}
            backbone_state.update(matched)
            self.body.load_state_dict(backbone_state)
            for param in self.body.parameters():
                param.requires_grad = False

    def forward(self, x):
        feat = self.body(x)
        feat = self.stage5(feat)
        # Drop the singleton spatial axis left after the stage-5 conv.
        feat = feat.squeeze(3)
        # (N, C, T) -> (T, N, C), the layout nn.LSTM expects.
        feat = feat.permute(2, 0, 1).contiguous()
        seq = self.rnn(feat)
        return F.log_softmax(seq, dim=2)
数据集的组织方式参照 MJSynth 数据集
import os

import cv2
import numpy as np
import torch
from torch.nn import CTCLoss
from torch.utils.data import DataLoader, Dataset, default_collate
class RegDataSet(Dataset):
    """Text-recognition dataset following the MJSynth layout.

    Each line of the annotation file is "<relative_img_path> <lexicon_index>";
    the lexicon file holds one ground-truth string per line.

    Args:
        dataset_root: directory containing images and the two text files.
        anno_txt_path: annotation file path, relative to dataset_root.
        lexicon_path: lexicon file path, relative to dataset_root.
        target_size: (width, height) of the output image.
        characters: index -> character table; index 0 ('-') is the CTC
            blank. (The original default was the literal string
            "'-' + '0123456789'" — the expression text, not its value —
            which made characters.find() return wrong label indices.)
        transform: optional callable applied to the image array.
    """

    def __init__(self, dataset_root, anno_txt_path, lexicon_path,
                 target_size=(200, 32), characters="-0123456789",
                 transform=None):
        super(RegDataSet, self).__init__()
        self.dataset_root = dataset_root
        self.anno_txt_path = anno_txt_path
        self.lexicon_path = lexicon_path
        self.target_size = target_size
        self.height = self.target_size[1]
        self.width = self.target_size[0]
        self.characters = characters
        self.imgs = []
        self.lexicons = []
        self.parse_txt()
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        img_path, lexicon_index = self.imgs[item].split()
        lexicon = self.lexicons[int(lexicon_index)].strip()
        img = cv2.imread(os.path.join(self.dataset_root, img_path))
        # NOTE(review): cv2.imread returns None for unreadable paths; the
        # next call would then raise — verify inputs upstream.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # Generalized from the hard-coded 6.4 (= 200/32) and 32.0 constants
        # so target_size other than (200, 32) works consistently.
        if (w / float(h)) < (self.width / float(self.height)):
            # Narrower than the target aspect: scale by height, right-pad
            # with black to the target width.
            new_w = int(float(self.height) / h * w)
            img_reshape = cv2.resize(img, (new_w, self.height))
            pad = np.zeros((self.height, self.width - new_w, 3), dtype=np.uint8)
            out_img = np.concatenate([img_reshape, pad], axis=1).transpose([1, 0, 2])
        else:
            # Wide enough: plain resize to the target size.
            out_img = cv2.resize(img, (self.width, self.height),
                                 interpolation=cv2.INTER_CUBIC)
            out_img = np.asarray(out_img).transpose([1, 0, 2])
        # str.find returns -1 for characters missing from the table;
        # labels are assumed to only contain known characters.
        label = [self.characters.find(c) for c in lexicon]
        if self.transform:
            out_img = self.transform(out_img)
        return out_img, label

    def parse_txt(self):
        """Read annotation and lexicon lines (with-blocks fix the original
        leaked file handles)."""
        with open(os.path.join(self.dataset_root, self.anno_txt_path), 'r') as f:
            self.imgs = f.readlines()
        with open(os.path.join(self.dataset_root, self.lexicon_path), 'r') as f:
            self.lexicons = f.readlines()
# Character table: index 0 ('-') is the CTC blank, indices 1-10 are '0'-'9'.
characters = "-0123456789"
# CTC loss with blank at index 0, averaged over the batch.
ctc_loss = CTCLoss(blank=0, reduction='mean')
损失计算:
ctc_loss(log_probs, targets, input_lengths, target_lengths)
个人实现代码如下:
def custom_collate_fn(batch, T=50):
    """Collate (image, label) samples into the four tensors nn.CTCLoss needs.

    Args:
        batch: list of (image, label_list) pairs from the dataset.
        T: number of time steps the network emits; every sample's input
           length is this constant.

    Returns:
        (images, targets, target_lengths, input_lengths) where targets is
        a flat 1-D tensor of all labels concatenated in batch order.
    """
    images, labels = zip(*batch)
    images = default_collate(list(images))
    # Number of target characters per sample.
    target_lengths = torch.tensor([len(label) for label in labels],
                                  dtype=torch.int)
    # The model always emits T time steps regardless of image content.
    input_lengths = torch.full((len(batch),), T, dtype=torch.int)
    # Concatenate every sample's labels into one flat target tensor.
    targets = torch.tensor([c for label in labels for c in label])
    return images, targets, target_lengths, input_lengths
# Pull one batch through the DataLoader using the CTC-aware collate function.
batch_iterator = iter(DataLoader(trainSet, args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=custom_collate_fn))
images, labels, target_lengths, input_lengths = next(batch_iterator)
# Forward pass; output is (T, N, n_classes) log-probabilities (log_softmax
# is applied inside CRNN.forward), the layout nn.CTCLoss expects.
out = net(images)
loss = ctc_loss(log_probs=out, targets=labels, target_lengths=target_lengths, input_lengths=input_lengths)
网络输出为 (50, 11), 解码先取每个位置的最大概率的字符index, index转str时,如果两个相同的index连续,那么合并为一个
例: 假设输出为: [1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 3, 0, 7, 7, 7, 0, 3, 3] , 由于后边全为0, 只取前18位. 0 对应的字符是 '-', 对于相邻的相同非0字符, 看做一个字符, 因此该例子为 [1,0,0,1,0,0,2,0,3,0,7,0,3], 再将0对应的blank 去掉, 则实际的字符index为 [1,1,2,3,7,3]
def decode_out(str_index, characters):
    """Greedy CTC decode: collapse repeats, then drop blanks (index 0).

    Args:
        str_index: sequence of per-timestep argmax indices (list or 1-D
            tensor of ints).
        characters: index -> character table; characters[0] is the blank.

    Returns:
        The decoded string.
    """
    decoded = []
    prev = None
    for idx in str_index:
        # Standard CTC collapse rule: emit a symbol only when it is not
        # the blank and not a repeat of the previous time step.
        if idx != 0 and idx != prev:
            decoded.append(characters[idx])
        prev = idx
    return ''.join(decoded)
# Inference: take the most likely class at every time step, then apply the
# CTC collapse rule to recover the text.
net_out = net(img)
_, preds = net_out.max(2)
# (T, N) -> flat index sequence (presumably N == 1 here — verify caller).
preds = preds.transpose(1, 0).contiguous().view(-1)
lab2str = decode_out(preds, args.characters)