动态导入模块,加载预训练模型,nn.Sequential函数里面必须是a Module subclass,不能是一个列表或者是其他的迭代器、生成器,虽然这里面包含了Module的子类

如果想看该sequence(假设sequence = nn.Sequential(......),直接打印该网络就可以:

print(sequence)

如果RES是一个model也可以直接打印来查看里面的结果:print(RES)

class RES(nn.Module):
    def __init__(self):
        super(RES, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(64)
        self.relu=nn.ReLU(inplace=True)
        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn2=nn.BatchNorm2d(128)
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.maxpool(x)
        x=self.conv2(x)
        x=self.bn2(x)
        return x

model=RES()
glb = nn.Sequential(*list(model.children())[:4])

有两点数据的说明:这个类继承了Module一定要用super函数

nn.Sequential函数里面的参数一定是Module的子类,而list:list is not a Module subclass。所以不能当做参数,当然model.children()也是一样:Module.children is not a Module subclass。这里的*就起了作用,将list或者children的内容迭代的一个一个的传进去,效果如下:

动态导入模块,加载预训练模型,nn.Sequential函数里面必须是a Module subclass,不能是一个列表或者是其他的迭代器、生成器,虽然这里面包含了Module的子类_第1张图片

 当然,我们还可以像最上面的那样,选取里面的几个Module,例如[:4]也就是第0个到第3个.

 

动态导入模块,使用importlib.import_module函数实际上是import了一个叫做resnet的文件,下面的语句相当于 import xxx as resnet

当然这里的xxx是该文件的实际路径

import importlib
resnet = importlib.import_module("torchvision.models.resnet")
resnet18=resnet.resnet18()
resnet34=resnet.resnet34()
resnet50=resnet.resnet50()
resnet101=resnet.resnet101()
resnet152=resnet.resnet152()

动态导入模块,加载预训练模型,nn.Sequential函数里面必须是a Module subclass,不能是一个列表或者是其他的迭代器、生成器,虽然这里面包含了Module的子类_第2张图片

其他的模块有:

"""
alexnet文件
"""
alexnet=importlib.import_module("torchvision.models.alexnet")
alexnet=alexnet.alexnet()
nn.Sequential(*alexnet.children())

"""
vgg文件
"""
vgg=importlib.import_module("torchvision.models.vgg")
vgg16=vgg.vgg16() # vgg11=vgg.vgg11(),vgg19=vgg.vgg19(),vgg13=vgg.vgg13()以及他们的bn形式
# vgg16_bn=vgg.vgg16_bn(),vgg11_bn=vgg.vgg11_bn(),vgg19_bn=vgg.vgg19_bn(),vgg13_bn=vgg.vgg13_bn()
nn.Sequential(*vgg16.children())

"""
densenet文件
"""
densenet=importlib.import_module("torchvision.models.densenet")
densenet121=densenet.densenet121() 
# densenet169=densenet.densenet169(),densenet201=densenet.densenet201(),densenet161=densenet.densenet161()
nn.Sequential(*densenet121.children())

"""
inception文件
"""
inception=importlib.import_module("torchvision.models.inception")
inception_v3=inception.inception_v3()
nn.Sequential(*inception_v3.children())

"""
squeezenet文件
"""
squeezenet=importlib.import_module("torchvision.models.squeezenet")
squeezenet1_0=inception.squeezenet1_0()
# squeezenet1_0=inception.squeezenet1_1()
nn.Sequential(*squeezenet1_0.children())

还有一种导入方式,是比较常用的,推荐的:

import torchvision.models as models
models.squeezenet1_0()

"""
models后面直接接的是网络
models的__init__文件如下
"""
from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
"""
可以看出来,导入的是这5个文件里面的函数(类)
*代表想对应文件的__all__,下面是各个文件的该属性以及训练好的权重
"""
# alexnet
__all__ = ['AlexNet', 'alexnet']
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
# resnet
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# vgg
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',]
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
# squeezenet
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}
# inception
__all__ = ['Inception3', 'inception_v3']
model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
# densenet
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}

所有的模型默认都是不加载预训练模型参数的,怎么加载预训练模型参数呢?很简单,就在括号里面的pretrained设置成True,如果仅仅是需要该结构而不需要预训练模型参数作为初始化,那么pretrained=False。

resnet50 = models.resnet50(pretrained=True)

推荐!这里有一篇比较综合https://blog.csdn.net/weixin_41278720/article/details/80759933

其中可以补充一点就是将参数进行下载,相比加载模型来说更加的节省资源

    import torch.utils.model_zoo as model_zoo

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
                                           '/home/zzp/SSD_ping/my-root-path/My-core-python/PretrainedWeights')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

 

你可能感兴趣的:(pytorch)