[pytorch] PyTorch—torchvision.models导入预训练模型—残差网络代码讲解

参考:https://blog.csdn.net/wsp_1138886114/article/details/83787181

PyTorch框架中torchvision模块下有:torchvision.datasets、torchvision.models、torchvision.transforms这3个子包。
关于详情请参考官网: http://pytorch.org/docs/master/torchvision/index.html。
具体代码可以参考github: https://github.com/pytorch/vision/tree/master/torchvision。

torchvision.models

此模块下有常用的 alexnet、densenet、inception、resnet、squeezenet、vgg(关于网络详情请查看)等常用的网络结构,并且提供了预训练模型,我们可以通过简单调用来读取网络结构和预训练模型,同时使用fine tuning(微调)来使用。
关于 fine tuning 可以查看 https://blog.csdn.net/hjxu2016/article/details/78424370
以残残差网路为例来讲解。

残差网络代码详解

ResNet主要有五种变形:Res18,Res34,Res50,Res101,Res152。

每个网络都包括三个主要部分:输入部分、输出部分和中间卷积部分(中间卷积部分包括如图所示的Stage1到Stage4共计四个stage)。尽管ResNet的变种形式丰富,但是都遵循上述的结构特点,网络之间的不同主要在于中间卷积部分的block参数和个数存在差异。[pytorch] PyTorch—torchvision.models导入预训练模型—残差网络代码讲解_第1张图片

具体代码参考github:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
论文连接:https://arxiv.org/abs/1512.03385

1. 模块调用

import torchvision

"""
    如果你需要用预训练模型,设置pretrained=True
    如果你不需要用预训练模型,设置pretrained=False,默认是False,你可以不写
"""
model = torchvision.models.resnet50(pretrained=True) 
model = torchvision.models.resnet50() 

# 你也可以导入densenet模型。且不需要是预训练的模型
model = torchvision.models.densenet169(pretrained=False)

2. 源码解析

以导入resnet50为例,介绍具体导入模型时候的源码。
运行 model = torchvision.models.resnet50(pretrained=True)的时候,是通过models包下的resnet.py脚本进行的,源码如下:

首先是导入必要的库,其中model_zoo是和导入预训练模型相关的包,另外all变量定义了可以从外部import的函数名或类名。这也是前面为什么可以用torchvision.models.resnet50()来调用的原因。
model_urls这个字典是预训练模型的下载地址。

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

接下来就是resnet50这个函数了,参数pretrained默认是False。

  1. model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)是构建网络结构,Bottleneck是另外一个构建bottleneck的类,在ResNet网络结构的构建中有很多重复的子结构,这些子结构就是通过Bottleneck类来构建的,后面会介绍。
  2. 如果参数pretrained是True,那么就会通过model_zoo.py中的load_url函数根据model_urls字典下载或导入相应的预训练模型。
  3. 通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)。
def resnet50(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model

其他resnet18、resnet101等函数和resnet50基本类似。

差别主要是在:
1、构建网络结构的时候block的参数不一样,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。
2、调用的block类不一样,比如在resnet50、resnet101、resnet152中调用的是Bottleneck类,而在resnet18和resnet34中调用的是BasicBlock类,这两个类的区别主要是在residual结果中卷积层的数量不同,这个是和网络结构相关的,后面会详细介绍。
3、如果下载预训练模型的话,model_urls字典的键不一样,对应不同的预训练模型。因此接下来分别看看如何构建网络结构和如何导入预训练模型。

# pretrained (bool): If True, returns a model pre-trained on ImageNet

def resnet18(pretrained=False, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

def resnet101(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model

你可能感兴趣的:(pytorch,深度学习)