街景字符编码识别赛事Task03

之前的baseline中的模型是建立在resnet的预训练模型的基础之上的。

将其中的池化层的属性从平均池化层修改为了自适应池化层

再将最后的全连接修改了 因为此次分类的种类数是11 所以全连接层的输出个数是11

class SVHN_Model1(nn.Module):
    def __init__(self):
        super(SVHN_Model1, self).__init__()
                
        model_conv = models.resnet18(pretrained=True)
        model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
        model_conv = nn.Sequential(*list(model_conv.children())[:-1])
        self.cnn = model_conv
        
        self.fc1 = nn.Linear(512, 11)
        self.fc2 = nn.Linear(512, 11)
        self.fc3 = nn.Linear(512, 11)
        self.fc4 = nn.Linear(512, 11)
        self.fc5 = nn.Linear(512, 11)
    
    def forward(self, img):        
        feat = self.cnn(img)
        # print(feat.shape)
        feat = feat.view(feat.shape[0], -1)
        c1 = self.fc1(feat)
        c2 = self.fc2(feat)
        c3 = self.fc3(feat)
        c4 = self.fc4(feat)
        c5 = self.fc5(feat)
        return c1, c2, c3, c4, c5

我们可以自己搭建卷积层 不适用卷积网络 不使用预训练模型 也可以在预训练的基础上再多加几层。

目前搭建网络层这一块还没有弄好 回头弄好了再进行更新。

你可能感兴趣的:(街景字符编码识别赛事Task03)