I've recently been working on a module of a project that involves text recognition. Text recognition generally breaks down into two steps, text detection and text recognition; this post focuses on building the model for the recognition part, where the mainstream approach is CRNN+CTC. So here is a hands-on example that I built myself, along with a few notes from the process.
While working on it I went through quite a few papers and websites. I especially recommend one Zhihu article that explains the topic really well: 一文读懂CRNN+CTC文字识别.
Getting text recognition all the way to a working system involves building the model, constructing your own dataset, and then training; this post covers building the model first.
Without further ado, here is the code.
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """
    Each ResidualBlock must keep the input and output dimensions identical,
    so both convolutions use the same number of channels.
    """
    def __init__(self, channel):
        super().__init__()
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)

    def forward(self, x):
        """
        The ResidualBlock has a skip connection: the block's input is added to
        the result of the second convolution before the final activation.
        This helps avoid vanishing gradients, because the identity term keeps
        the gradient of the block at dy + 1, i.e. close to 1.
        """
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y)


class myLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, nout):
        super(myLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2, bidirectional=True)
        self.linear = nn.Linear(2 * hidden_size, nout)  # bidirectional => 2 * hidden_size

    def forward(self, input):
        output, (h_n, c_n) = self.lstm(input)
        T, B, H = output.size()
        rec = output.view(T * B, H)
        lout = self.linear(rec)
        lout = lout.view(T, B, -1)  # back to [T, B, nout] to match the CTC loss input
        return lout


class myCNN(nn.Module):
    def __init__(self):
        super(myCNN, self).__init__()
        self.resblock1 = ResidualBlock(32)
        self.resblock2 = ResidualBlock(64)
        self.resblock3 = ResidualBlock(128)
        self.conv1 = nn.Sequential(
            # layer_1
            nn.Conv2d(3, 32, 3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            # layer_2
            nn.Conv2d(32, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            # layer_3
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

    def forward(self, input):
        out1 = self.conv1(input)
        out1 = self.resblock1(out1)
        out2 = self.conv2(out1)
        out2 = self.resblock2(out2)
        out3 = self.conv3(out2)
        out3 = self.resblock3(out3)
        return out3


class CRNN(nn.Module):
    def __init__(self, nclass, nhidden):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(myCNN())
        self.lstm = nn.Sequential(
            myLSTM(4 * 128, nhidden, nhidden),   # 4 * 128 = feature-map height * channels
            myLSTM(nhidden, nhidden, nclass),
        )

    def forward(self, input):
        conv = self.cnn(input)                   # [B, 128, 4, W']
        batch, channel, h, w = conv.shape
        # print(conv.shape)
        conv = conv.permute(0, 3, 2, 1)          # [B, W', 4, 128]
        conv = conv.reshape(batch, -1, 4 * 128)  # [B, W', 512]
        conv = conv.permute(1, 0, 2)             # LSTM input layout [T, B, C]
        out = self.lstm(conv)
        return out
The input images are W x 32 x 3: the width is not fixed, but all images in a dataset should be resized to the same size to simplify training. If your image height is not 32, just adjust the reshaping based on the CNN's output shape; with the network above, the feature-map height after the convolutions is 4.
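To make the shape bookkeeping concrete, here is a quick sanity check. This is only a minimal sketch: the width of 100, the batch size of 8, and the nclass/nhidden values are example numbers I picked, not part of the model.

import torch

# Example values only: 100-pixel-wide images, 37 classes, hidden size 256.
model = CRNN(nclass=37, nhidden=256)
model.eval()

dummy = torch.randn(8, 3, 32, 100)   # [B, C, H, W]
with torch.no_grad():
    feat = model.cnn(dummy)          # CNN feature map
    out = model(dummy)               # full CRNN output

print(feat.shape)  # torch.Size([8, 128, 4, 21]) -> height is 4, as noted above
print(out.shape)   # torch.Size([21, 8, 37])     -> [T, B, nclass], ready for CTC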
For the CNN we build a simple ResNet-style stack; the LSTM is bidirectional, so be careful with its output size (2 * hidden_size).
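Since the only reason for the final [T, B, nclass] layout is the CTC loss, here is a minimal sketch of how the output would plug into PyTorch's nn.CTCLoss, continuing from the sanity check above (reusing model and dummy). The targets and lengths below are random placeholders; the real training loop will be in the next post.

import torch
import torch.nn as nn

criterion = nn.CTCLoss(blank=0, zero_infinity=True)

log_probs = model(dummy).log_softmax(2)   # CTCLoss expects log-probabilities of shape [T, B, nclass]
T, B, _ = log_probs.shape

input_lengths = torch.full((B,), T, dtype=torch.long)          # every time step is used
target_lengths = torch.randint(1, 10, (B,), dtype=torch.long)  # placeholder label lengths
targets = torch.randint(1, 37, (int(target_lengths.sum()),), dtype=torch.long)  # labels 1..36, 0 = blank

loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())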
The code is based on the PyTorch framework and is quite concise.
Writing this up takes effort, so likes are appreciated. The training code and the dataset-construction code will come in the next post.