PyTorch 版 CRNN 网络框架

这是常见的 PyTorch 版 CRNN 网络框架。我遵循老师的思路,用 C++ 和 LibTorch 也搭建了一个一模一样的框架,主要用于部署。LibTorch 版本的代码我暂时放在私密文章里,还没有公开,有需要的可以私信我。

import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        nIn: Feature size of each input time step.
        nHidden: Hidden size of each LSTM direction.
        nOut: Feature size of each output time step.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        # Forward and backward hidden states are concatenated by the LSTM,
        # so the projection consumes 2 * nHidden features.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM over ``input`` and project every time step.

        Args:
            input: Tensor laid out as (seq_len, batch, nIn).

        Returns:
            Tensor of shape (seq_len, batch, nOut).
        """
        recurrent = self.rnn(input)[0]  # hidden/cell states are discarded
        seq_len, batch, features = recurrent.size()

        # Collapse time and batch so the linear layer sees a 2-D matrix,
        # then restore the (seq_len, batch, nOut) layout.
        projected = self.embedding(recurrent.view(seq_len * batch, features))
        return projected.view(seq_len, batch, -1)

class CRNN(nn.Module):
    """CNN feature extractor followed by two stacked bidirectional LSTM blocks.

    The convolutional stack must reduce the feature-map height to exactly 1;
    every remaining column is then treated as one time step of the sequence
    fed to the recurrent layers.

    Args:
        imgH: Nominal input image height; must be a multiple of 16. It is only
            validated here and is not otherwise used to build the network.
        nc: Number of input image channels.
        nclass: Number of output classes per time step.
        nh: Hidden size of each LSTM direction.
        n_rnn: Accepted for interface compatibility; currently unused.
        leakyRelu: Use ``LeakyReLU(0.2)`` activations instead of ``ReLU``.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        # NOTE: this check is necessary but not sufficient for the final
        # feature height to be 1 (e.g. 48 passes here but yields height 2);
        # forward() verifies the actual height.
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
        ks = [3, 3, 3, 3, 3, 3, 2]  # kernel size per conv layer
        ps = [1, 1, 1, 1, 1, 1, 0]  # padding per conv layer
        ss = [1, 1, 1, 1, 1, 1, 1]  # stride per conv layer
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels per layer

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv layer ``i``, optional batch norm, and activation."""
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Shape comments are channels x height x width for a 32 x 128 input.
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        # Stride (2, 1) with padding (0, 1): halves the height only, while the
        # width grows by one column.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x33
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x34
        convRelu(6, True)  # 512x1x33

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        """Compute per-time-step log-probabilities for a batch of images.

        Args:
            input: Tensor laid out as (batch, nc, height, width).

        Returns:
            Tensor of shape (seq_len, batch, nclass) holding log-softmax
            scores over the class dimension, where seq_len is the width of
            the convolutional feature map.

        Raises:
            RuntimeError: If the convolutional feature map height is not 1.
        """
        conv = self.cnn(input)

        _, _, h, _ = conv.size()
        if h != 1:
            # squeeze(2) would silently leave a non-unit height in place and
            # the permute below would then fail with an unrelated-looking
            # dimension error, so report the real cause here.
            raise RuntimeError(
                'the height of conv must be 1, got {0}'.format(h))

        conv = conv.squeeze(2)        # [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        return F.log_softmax(self.rnn(conv), dim=2)

def weights_init(m):
    """Re-initialise ``m`` in place, dispatching on its class name.

    Meant to be passed to ``Module.apply``. Modules whose class name contains
    ``Conv`` get weights drawn from N(0.0, 0.02); their bias is left alone.
    Modules whose class name contains ``BatchNorm`` get weights drawn from
    N(1.0, 0.02) and a zeroed bias. Every other module is untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'BatchNorm' in kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)


def get_crnn(config):
    """Build a single-channel CRNN from ``config`` and initialise its weights.

    Reads ``config.MODEL.IMAGE_SIZE.H``, ``config.MODEL.NUM_CLASSES`` and
    ``config.MODEL.NUM_HIDDEN``. The network is created with
    ``NUM_CLASSES + 1`` output classes (presumably the extra class is a
    CTC blank -- confirm against the training code). The model is printed
    to stdout, then ``weights_init`` is applied to every submodule.

    Returns:
        The initialised ``CRNN`` instance.
    """
    net = CRNN(
        imgH=config.MODEL.IMAGE_SIZE.H,
        nc=1,
        nclass=config.MODEL.NUM_CLASSES + 1,
        nh=config.MODEL.NUM_HIDDEN,
    )
    print("model--->", net)
    net.apply(weights_init)
    return net

你可能感兴趣的:(pytorch学习笔记,pytorch,crnn)