如果想看该sequence(假设sequence = nn.Sequential(......),直接打印该网络就可以:
print(sequence)
如果RES是一个model也可以直接打印来查看里面的结果:print(RES)
class RES(nn.Module):
def __init__(self):
super(RES, self).__init__()
self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
self.bn1=nn.BatchNorm2d(64)
self.relu=nn.ReLU(inplace=True)
self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False)
self.bn2=nn.BatchNorm2d(128)
def forward(self,x):
x=self.conv1(x)
x=self.bn1(x)
x=self.relu(x)
x=self.maxpool(x)
x=self.conv2(x)
x=self.bn2(x)
return x
model=RES()
glb = nn.Sequential(*list(model.children())[:4])
有两点数据的说明:这个类继承了Module一定要用super函数
nn.Sequential函数里面的参数一定是Module的子类,而list:list is not a Module subclass。所以不能当做参数,当然model.children()也是一样:Module.children is not a Module subclass。这里的*就起了作用,将list或者children的内容迭代的一个一个的传进去,效果如下:
当然,我们还可以像最上面的那样,选取里面的几个Module,例如[:4]也就是第0个到第3个.
动态导入模块,使用importlib.import_module函数实际上是import了一个叫做resnet的文件,下面的语句相当于 import xxx as resnet
当然这里的xxx是该文件的实际路径
import importlib
resnet = importlib.import_module("torchvision.models.resnet")
resnet18=resnet.resnet18()
resnet34=resnet.resnet34()
resnet50=resnet.resnet50()
resnet101=resnet.resnet101()
resnet152=resnet.resnet152()
其他的模块有:
"""
alexnet文件
"""
alexnet=importlib.import_module("torchvision.models.alexnet")
alexnet=alexnet.alexnet()
nn.Sequential(*alexnet.children())
"""
vgg文件
"""
vgg=importlib.import_module("torchvision.models.vgg")
vgg16=vgg.vgg16() # vgg11=vgg.vgg11(),vgg19=vgg.vgg19(),vgg13=vgg.vgg13()以及他们的bn形式
# vgg16_bn=vgg.vgg16_bn(),vgg11_bn=vgg.vgg11_bn(),vgg19_bn=vgg.vgg19_bn(),vgg13_bn=vgg.vgg13_bn()
nn.Sequential(*vgg16.children())
"""
densenet文件
"""
densenet=importlib.import_module("torchvision.models.densenet")
densenet121=densenet.densenet121()
# densenet169=densenet.densenet169(),densenet201=densenet.densenet201(),densenet161=densenet.densenet161()
nn.Sequential(*densenet121.children())
"""
inception文件
"""
inception=importlib.import_module("torchvision.models.inception")
inception_v3=inception.inception_v3()
nn.Sequential(*inception_v3.children())
"""
squeezenet文件
"""
squeezenet=importlib.import_module("torchvision.models.squeezenet")
squeezenet1_0=inception.squeezenet1_0()
# squeezenet1_0=inception.squeezenet1_1()
nn.Sequential(*squeezenet1_0.children())
import torchvision.models as models
models.squeezenet1_0()
"""
models后面直接接的是网络
models的__init__文件如下
"""
from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
"""
可以看出来,导入的是这5个文件里面的函数(类)
*代表想对应文件的__all__,下面是各个文件的该属性以及训练好的权重
"""
# alexnet
__all__ = ['AlexNet', 'alexnet']
model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
# resnet
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# vgg
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',]
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
# squeezenet
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}
# inception
__all__ = ['Inception3', 'inception_v3']
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
# densenet
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
所有的模型默认都是不加载预训练模型参数的,怎么加载预训练模型参数呢?很简单,就在括号里面的pretrained设置成True,如果仅仅是需要该结构而不需要预训练模型参数作为初始化,那么pretrained=False。
resnet50 = models.resnet50(pretrained=True)
其中可以补充一点就是将参数进行下载,相比加载模型来说更加的节省资源
import torch.utils.model_zoo as model_zoo
def _load_pretrained_model(self):
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'/home/zzp/SSD_ping/my-root-path/My-core-python/PretrainedWeights')
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)