Pytorch中加载预训练模型以及冻结层

一、加载预训练模型

加载方式有两种,主要是第二种对于模型finetune比较常用

1、加载框架已有的模型(如resnet等)

代码如下:

import torch
import torch.nn as nn
from torch.utils import model_zoo
import torchvision.models as models

model = models.resnet18()
    model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
    }
model.load_state_dict(model_zoo.load_url(model_urls['resnet18]), strict=False)
# 其中主要是strict=False,假设你针对原resnet18模型添加了自己的层,那么这个strict=False就会只加载name相同的参数

2、加载预训练好的模型

代码如下:

model = EfficientNet1()
# model.load_state_dict(model_zoo.load_url('https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth'), strict=False)
model_dict = model.state_dict()
sd = torch.load('/root/.cache/torch/checkpoints/efficientnet-b0-355c32eb.pth')
pretrained_dict = {k:v for k, v in sd.items() if k in model_dict.keys()}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

3、冻结某些层

代码如下:

 # 这里freeze了layer1之前的层(包括layer1),以及所有的bn层
 for k,v in model.named_parameters():
     if k.startswith('conv1') or k.startswith('layer1'):
         v.requires_grad = False
 for m in model.modules():
     if isinstance(m, nn.BatchNorm2d):
         m.eval()
         m.weight.requires_grad = False
         m.bias.requires_grad = False

你可能感兴趣的:(Pytorch中加载预训练模型以及冻结层)