Modifying and trimming pre-trained models in PyTorch

I. Pre-trained models in PyTorch

Training a convolutional neural network is time-consuming, and in many situations it is impractical to train from randomly initialized parameters every time.

PyTorch (via torchvision) ships with pre-trained versions of several widely used deep-learning networks, such as VGG and ResNet. To speed up the early stage of training, we often start by loading the parameters of a pre-trained model. Loading a model looks like this:

    import torchvision.models as models

    # ResNet (each factory function builds the network and, with pretrained=True,
    # downloads the ImageNet weights)
    model = models.resnet18(pretrained=True)
    model = models.resnet34(pretrained=True)
    model = models.resnet50(pretrained=True)

    # VGG
    model = models.vgg11(pretrained=True)
    model = models.vgg16(pretrained=True)
    model = models.vgg16_bn(pretrained=True)
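
With pretrained=True the weights are downloaded and cached automatically, and the loaded model can be used for inference right away. Below is a minimal usage sketch (not from the original post), assuming the standard 224x224 ImageNet-style input these torchvision models expect:

    import torch
    import torchvision.models as models

    model = models.resnet18(pretrained=True)
    model.eval()  # switch BatchNorm/Dropout layers to inference mode

    # forward a dummy batch containing one 3x224x224 image
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)      # shape: [1, 1000]
    print(logits.argmax(dim=1))    # index of the predicted ImageNet class
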
--------------------------------------------------------
    # Save/Load Entire Model
    torch.save(model, PATH)
    # Load
    model = torch.load(PATH)
    model.eval()

--------------------------------------------------------
    # Saving & Loading a General Checkpoint for Inference and/or Resuming Training
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
    # Load   
    model = TheModelClass(*args, **kwargs)
    optimizer = TheOptimizerClass(*args, **kwargs)

    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    model.eval()
    # - or -
    model.train()
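
As a usage sketch (not part of the original snippet), such a checkpoint is typically written once per epoch inside the training loop and read back when resuming; num_epochs, dataloader and criterion below are assumed to be defined elsewhere:

    import os

    # resume from the checkpoint if one already exists
    start_epoch = 0
    if os.path.exists(PATH):
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1

    model.train()
    for epoch in range(start_epoch, num_epochs):
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        # overwrite the checkpoint at the end of every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }, PATH)
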
-----------------------------------------------------------
    # Save: only the state_dict (the recommended approach)
    torch.save(model.state_dict(), PATH)
    # Load on CPU:
    device = torch.device("cpu")
    model = Model_Def()  # Model_Def is your own model class
    model.load_state_dict(torch.load(PATH, map_location=device))
    # Load on GPU:
    device = torch.device("cuda")
    model = Model_Def()
    model.load_state_dict(torch.load(PATH))
    model.to(device)
-------------------------------------------------------
    # Save multiple models in one file
    torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)
    # Load
    modelA = TheModelAClass(*args, **kwargs)
    modelB = TheModelBClass(*args, **kwargs)
    optimizerA = TheOptimizerAClass(*args, **kwargs)
    optimizerB = TheOptimizerBClass(*args, **kwargs)

    checkpoint = torch.load(PATH)
    modelA.load_state_dict(checkpoint['modelA_state_dict'])
    modelB.load_state_dict(checkpoint['modelB_state_dict'])
    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

    modelA.eval()
    modelB.eval()
    # - or -
    modelA.train()
    modelB.train()

 

II. Modifying a pre-trained model

1. Modifying parameters

For simple parameter changes, we take the pre-trained ResNet model as an example; the ResNet source code is available on GitHub.

The final classification layer of ResNet, fc, maps to 1000 classes. If your own dataset has, say, only 9 classes, modify it as follows:

    # coding=UTF-8
    import torchvision.models as models
    import torch.nn as nn

    # load the model
    model = models.resnet50(pretrained=True)

    # read the input dimension of the fc layer
    fc_features = model.fc.in_features

    # replace fc so that it outputs 9 classes
    model.fc = nn.Linear(fc_features, 9)
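
If you only want to train the new classifier and keep the pre-trained backbone fixed, a common companion step is to freeze the existing parameters before replacing fc and to hand only the trainable parameters to the optimizer. A minimal sketch (the SGD settings are placeholders, not from the original post):

    import torch
    import torch.nn as nn
    import torchvision.models as models

    model = models.resnet50(pretrained=True)

    # freeze all pre-trained parameters
    for param in model.parameters():
        param.requires_grad = False

    # the newly created fc layer has requires_grad=True by default
    fc_features = model.fc.in_features
    model.fc = nn.Linear(fc_features, 9)

    # optimize only the parameters that still require gradients
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-3, momentum=0.9)
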

 

2. Adding or removing convolutional layers

The previous method only works for simple parameter changes. Sometimes we need to modify the layer structure of the network itself, and then the only option is parameter overwriting: define a similar network of our own, then copy the pre-trained parameters into it. Again we use the pre-trained ResNet model as an example.

    # coding=UTF-8
    import torchvision.models as models
    import torch
    import torch.nn as nn
    import math
    from torchvision.models.resnet import Bottleneck  # the residual block used by resnet50


    class CNN(nn.Module):
        def __init__(self, block, layers, num_classes=9):
            self.inplanes = 64
            super(CNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.avgpool = nn.AvgPool2d(7, stride=1)

            # add a transposed-convolution layer
            self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1, output_padding=0, groups=1, bias=False, dilation=1)
            # add another max-pooling layer
            self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
            # drop the original fc layer and add a new fclass layer instead
            self.fclass = nn.Linear(2048, num_classes)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
    
        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
    
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))
    
            return nn.Sequential(*layers)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)

            x = self.avgpool(x)
            # forward pass through the newly added layers
            x = self.convtranspose1(x)
            x = self.maxpool2(x)
            x = x.view(x.size(0), -1)
            x = self.fclass(x)

            return x

    # load the pre-trained model
    resnet50 = models.resnet50(pretrained=True)
    cnn = CNN(Bottleneck, [3, 4, 6, 3])
    # read the parameters
    pretrained_dict = resnet50.state_dict()
    model_dict = cnn.state_dict()
    # drop the keys in pretrained_dict that do not exist in model_dict
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # merge them into the existing model_dict
    model_dict.update(pretrained_dict)
    # load the state_dict we actually need
    cnn.load_state_dict(model_dict)
    # print(resnet50)
    print(cnn)
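
A quick sanity check (a sketch, assuming the standard 224x224 input size) is to run a dummy batch through the modified network and to confirm that the copied weights really match the pre-trained ones:

    # the new forward path should produce num_classes=9 logits
    cnn.eval()
    with torch.no_grad():
        out = cnn(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 9])

    # the copied conv1 weights should be identical to the pre-trained ResNet-50
    print(torch.equal(cnn.conv1.weight, resnet50.conv1.weight))  # prints True
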

 
