引言
- CRNN 是经典的文本识别算法。本文主要用于夯实基础，掌握 CRNN 的基本原理及其 PyTorch 实现。
基本原理
核心代码实现
import torch
from torch import nn
import torch.nn.functional as F
class CRNN(nn.Module):
    """Convolutional recurrent network producing per-column class scores.

    A seven-layer convolutional stack (with four max-pooling stages) reduces
    the input image to a feature map whose height must be 1. Each remaining
    column is then treated as one time step and passed through two stacked
    ``BidirectionalLSTM`` blocks, the second of which projects to
    ``n_class`` outputs.

    Args:
        img_height: Height of the input images; must be a multiple of 16.
            With this layer configuration a height of 32 is what collapses
            to a feature-map height of 1 in ``forward``.
        input_channel: Number of channels of the input images.
        n_class: Size of the last dimension of the output.
        hidden_size: Hidden size of each LSTM direction.

    Raises:
        ValueError: If ``img_height`` is not a multiple of 16.
    """

    def __init__(self, img_height, input_channel, n_class, hidden_size):
        super().__init__()

        if img_height % 16 != 0:
            raise ValueError('img_height has to be a multiple of 16')

        # Per-layer hyper-parameters of the seven convolutions.
        kernel_size = [3, 3, 3, 3, 3, 3, 2]
        padding_size = [1, 1, 1, 1, 1, 1, 0]
        stride = [1, 1, 1, 1, 1, 1, 1]
        channel = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            """Append conv ``i`` (+ optional batch norm) and a ReLU to ``cnn``."""
            in_channels = input_channel if i == 0 else channel[i - 1]
            out_channels = channel[i]
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(in_channels, out_channels,
                          kernel_size[i], stride[i], padding_size[i]),
            )
            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(out_channels))
            cnn.add_module(f'relu{i}', nn.ReLU(True))

        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))
        conv_relu(2, True)
        conv_relu(3)
        # From here on the pooling halves only the height: stride 1 and
        # padding 1 along the width keep the horizontal resolution.
        cnn.add_module('pooling2',
                       nn.MaxPool2d(kernel_size=(2, 2),
                                    stride=(2, 1),
                                    padding=(0, 1)))
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d(kernel_size=(2, 2),
                                    stride=(2, 1),
                                    padding=(0, 1)))
        # Final 2x2 convolution without padding removes the last extra row.
        conv_relu(6, True)

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, n_class),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor of shape ``(batch, channel, height, width)``.

        Returns:
            Log-probabilities of shape ``(time_steps, batch, n_class)``,
            normalised over the last dimension.

        Raises:
            ValueError: If the convolutional feature map does not have
                height 1.
        """
        cnn_feature = self.cnn(x)
        if cnn_feature.size(2) != 1:
            raise ValueError("the height of cnn_feature must be 1")

        # (batch, channel, 1, width) -> (width, batch, channel): every
        # feature-map column becomes one time step of the sequence.
        cnn_feature = cnn_feature.squeeze(2).permute(2, 0, 1)

        output = self.rnn(cnn_feature)
        # Normalise the RNN scores over the class dimension. The original
        # code applied log_softmax to the input image and discarded it.
        return F.log_softmax(output, dim=2)
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection of every time step.

    Args:
        input_size: Feature size of each input time step.
        hidden_size: Hidden size of each LSTM direction.
        out_feature: Feature size of each output time step.
    """

    def __init__(self, input_size, hidden_size, out_feature):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
        # Both directions are concatenated, so the projection sees twice
        # the hidden size.
        self.embedding = nn.Linear(2 * hidden_size, out_feature)

    def forward(self, x):
        """Map a ``(steps, batch, input_size)`` sequence to
        ``(steps, batch, out_feature)``."""
        lstm_out = self.rnn(x)[0]
        steps, batch, width = lstm_out.shape
        # Flatten time and batch so the linear layer sees a 2-D matrix,
        # then restore the sequence layout.
        flat = lstm_out.view(steps * batch, width)
        projected = self.embedding(flat)
        return projected.view(steps, batch, -1)
if __name__ == '__main__':
    # Smoke test: push one random single-channel 32x320 image through the
    # network and print the shape of the result.
    sample = torch.randn(1, 1, 32, 320)
    model = CRNN(img_height=32, input_channel=1, n_class=26, hidden_size=256)
    prediction = model(sample)
    print(prediction.shape)