夯实基础系列:CRNN

引言

  • CRNN 是经典的文本识别算法。本文旨在夯实基础:掌握 CRNN 的基本原理及其 PyTorch 实现。

基本原理

核心代码实现

import torch
from torch import nn
import torch.nn.functional as F


class CRNN(nn.Module):
    """Convolutional recurrent network for text-line recognition.

    A seven-layer convolutional stack reduces the input image to a feature
    map of height 1. The columns of that map are read as a sequence by two
    stacked bidirectional LSTM blocks, and the result is normalised with
    ``log_softmax`` over the class dimension.

    Args:
        img_height: Height of the input images; must be a multiple of 16.
        input_channel: Number of channels of the input images.
        n_class: Number of output classes per time step.
        hidden_size: Hidden size of each LSTM direction.

    Raises:
        ValueError: If ``img_height`` is not a multiple of 16.

    Note:
        The four pooling layers each halve the height and the last
        convolution (kernel 2, no padding) removes one more row, so with this
        layout only ``img_height == 32`` ends up with the height-1 feature
        map that ``forward`` requires.
    """

    def __init__(self, img_height, input_channel, n_class, hidden_size):
        super().__init__()

        if img_height % 16 != 0:
            raise ValueError('img_height has to be a multiple of 16')

        # Per-layer hyper-parameters of the seven convolutions.
        kernel_size = [3, 3, 3, 3, 3, 3, 2]
        padding_size = [1, 1, 1, 1, 1, 1, 0]
        stride = [1, 1, 1, 1, 1, 1, 1]
        channel = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            """Append conv layer ``i`` (+ optional batch norm) and a ReLU."""
            in_channels = input_channel if i == 0 else channel[i - 1]
            out_channels = channel[i]
            cnn.add_module(f'conv{i}',
                           nn.Conv2d(in_channels, out_channels,
                                     kernel_size[i],
                                     stride[i],
                                     padding_size[i]))

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(out_channels))
            cnn.add_module(f'relu{i}', nn.ReLU(True))

        # Shape comments below trace an example input of 1 x 32 x 320.
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x16x160

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x8x80

        conv_relu(2, True)
        conv_relu(3)
        # Stride 1 along the width keeps the sequence long; only the height
        # is halved.
        cnn.add_module('pooling2',
                       nn.MaxPool2d(kernel_size=(2, 2),
                                    stride=(2, 1),
                                    padding=(0, 1)))  # 256x4x81

        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x82
        conv_relu(6, True)  # 512x1x81

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, n_class)
        )

    def forward(self, x):
        """Return per-time-step log-probabilities for a batch of images.

        Args:
            x: Image batch of shape ``[batch, input_channel, height, width]``.

        Returns:
            Tensor of shape ``[seq_len, batch, n_class]`` whose last dimension
            is log-softmax normalised.

        Raises:
            ValueError: If the convolutional feature map is not of height 1.
        """
        features = self.cnn(x)

        # e.g. 1 x 512 x 1 x 81
        if features.size(2) != 1:
            raise ValueError("the height of cnn_feature must be 1")

        # Drop the height axis, then move width to the front so that it
        # becomes the sequence axis: [seq_len, batch, channels].
        features = features.squeeze(2).permute(2, 0, 1)

        output = self.rnn(features)
        # e.g. [81, 1, 26]; normalise the RNN output over the class axis
        # (the original normalised the input image and discarded the result).
        return F.log_softmax(output, dim=2)


class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        input_size: Feature size of each input time step.
        hidden_size: Hidden size of each LSTM direction.
        out_feature: Feature size of each output time step.
    """

    def __init__(self, input_size, hidden_size, out_feature):
        super().__init__()
        self.rnn = nn.LSTM(input_size=input_size,
                           hidden_size=hidden_size,
                           bidirectional=True)
        # The LSTM concatenates both directions, so the projection reads
        # twice the hidden size.
        self.embedding = nn.Linear(in_features=2 * hidden_size,
                                   out_features=out_feature)

    def forward(self, x):
        """Map ``[seq_len, batch, input_size]`` to ``[seq_len, batch, out_feature]``."""
        lstm_out, _ = self.rnn(x)
        seq_len, batch, width = lstm_out.shape

        # Fold time and batch together so the linear layer sees one row per
        # (time step, sample) pair, then unfold again.
        flat = lstm_out.view(seq_len * batch, width)
        projected = self.embedding(flat)  # [seq_len * batch, out_feature]
        return projected.view(seq_len, batch, -1)


if __name__ == '__main__':
    # Smoke test: push one random single-channel 32 x 320 image through the
    # network and print the shape of the result.
    dummy_batch = torch.randn(1, 1, 32, 320)
    model = CRNN(img_height=32, input_channel=1, n_class=26, hidden_size=256)
    print(model(dummy_batch).shape)

你可能感兴趣的:(深度学习,lstm,rnn)