这篇博客介绍torchvision.models。torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
import torchvision
# Build ResNet-50 and load its pre-trained weights.
model = torchvision.models.resnet50(pretrained=True)
# Build ResNet-50 without loading pre-trained weights.
model = torchvision.models.resnet50(pretrained=False)
# Same pattern for a different architecture: DenseNet-169 without pre-trained weights.
model = torchvision.models.densenet169(pretrained=False)
# No `pretrained` argument given; presumably equivalent to the line above
# if densenet169 also defaults to pretrained=False -- TODO confirm.
model = torchvision.models.densenet169()
接下来以导入resnet50为例,具体介绍导入模型时源码的运行过程,即执行 model = torchvision.models.resnet50(pretrained=True) 时发生了什么。
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
# Public names exported by this module. The list was cut off after
# 'resnet101'; 'resnet152' is restored to match the keys of model_urls.
# NOTE(review): only some of these constructors are shown in this excerpt --
# confirm the remaining ones exist in the full module.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']
# Maps each ResNet variant name to the URL of its pre-trained checkpoint.
# The file names follow the ``<name>-<sha256 prefix>.pth`` convention that
# ``load_url`` relies on to verify downloads.
# NOTE(review): the URLs were empty in the original text; the values below are
# the official torchvision checkpoint locations -- verify before relying on them.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
接下来就是resnet50这个函数了,参数pretrained默认是False。首先通过 model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 构建网络结构;如果pretrained为True,再导入预训练模型的参数,最后返回model。
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: Extra keyword arguments forwarded to the ``ResNet``
            constructor (e.g. ``num_classes``).

    Returns:
        The constructed ``ResNet`` instance.
    """
    # ResNet-50 stacks Bottleneck blocks, with 3/4/6/3 blocks in its four stages.
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # Fetch (or reuse the cached copy of) the checkpoint and copy its
        # parameters into the freshly built network.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model
其他resnet18、resnet101等函数和resnet50基本类似,差别主要是在:1、构建网络结构的时候block的参数不一样,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。2、调用的block类不一样,比如在resnet50、resnet101、resnet152中调用的是Bottleneck类,而在resnet18和resnet34中调用的是BasicBlock类,这两个类的区别主要是在residual结构中卷积层的数量不同,这个是和网络结构相关的,后面会详细介绍。3、如果下载预训练模型的话,model_urls字典的键不一样,对应不同的预训练模型。因此接下来分别看看如何构建网络结构和如何导入预训练模型。
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: Extra keyword arguments forwarded to the ``ResNet``
            constructor (e.g. ``num_classes``).

    Returns:
        The constructed ``ResNet`` instance.
    """
    # ResNet-18 stacks BasicBlock blocks, two per stage.
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        # Fetch (or reuse the cached copy of) the checkpoint and copy its
        # parameters into the freshly built network.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: Extra keyword arguments forwarded to the ``ResNet``
            constructor (e.g. ``num_classes``).

    Returns:
        The constructed ``ResNet`` instance.
    """
    # ResNet-101 stacks Bottleneck blocks, with 3/4/23/3 blocks in its four stages.
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        # Fetch (or reuse the cached copy of) the checkpoint and copy its
        # parameters into the freshly built network.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
class ResNet(nn.Module):
    """ResNet feature extractor followed by a linear classifier.

    Args:
        block: Residual block class used to build every stage. It must expose
            an ``expansion`` class attribute and be constructible as
            ``block(inplanes, planes, stride, downsample)``.
        layers: Sequence of four ints giving the number of blocks per stage.
        num_classes (int): Size of the output of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Channel count fed into the next block; updated by _make_layer.
        self.inplanes = 64
        # Stem: 7x7 stride-2 convolution, BN, ReLU, then 3x3 stride-2 max-pool.
        # NOTE(review): the original line was truncated after ``padding=3,``;
        # ``bias=False`` matches every other convolution in this file.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; every stage after the first halves the
        # spatial resolution through stride=2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # A 7x7 window covers the whole final feature map for 224x224 inputs
        # (overall stride of the network is 32).
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal init with std = sqrt(2 / fan_out).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                # NOTE(review): this branch had no body in the original text;
                # identity scale / zero shift is the usual BN initialisation.
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage made of ``blocks`` consecutive residual blocks.

        Args:
            block: Residual block class (see the class docstring).
            planes (int): Base channel width of the stage; the stage outputs
                ``planes * block.expansion`` channels.
            blocks (int): Number of residual blocks in the stage.
            stride (int): Stride applied by the first block of the stage.

        Returns:
            An ``nn.Sequential`` holding the blocks of the stage.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # The shortcut must be projected whenever the first block changes
            # the spatial size or the channel count.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        # Only the first block receives the stride and the projection shortcut.
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the classifier on ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
Bottleneck类继承自nn.Module,主要包含__init__和forward方法。从forward方法可以看出,Bottleneck就是我们熟悉的3个主要的卷积层、BN层和激活层,最后的out += residual就是逐元素相加(element-wise add)的操作。
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The block outputs ``planes * expansion`` channels.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Channel width of the two inner convolutions.
        stride (int): Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch when its shape differs from the main branch.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Use the class attribute instead of a literal 4 so the width stays in
        # sync with what ResNet computes from ``block.expansion``.
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the input so it matches the main branch's shape.
            residual = self.downsample(x)

        # Element-wise addition of the shortcut, then the final activation.
        out += residual
        out = self.relu(out)

        return out
BasicBlock类和Bottleneck类类似,前者主要是用来构建ResNet18和ResNet34网络,因为这两个网络的residual结构只包含两个卷积层,没有Bottleneck类中的bottleneck概念。因此在该类中,第一个卷积层采用的是kernel_size=3的卷积,如conv3x3函数所示。
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution.

    Returns:
        An ``nn.Conv2d`` with a 3x3 kernel, padding 1 and no bias term.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3).

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both convolutions.
        stride (int): Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch when its shape differs from the main branch.
    """

    # Output channels equal ``planes``; there is no channel expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Only the first convolution carries the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the input so it matches the main branch's shape.
            residual = self.downsample(x)

        # Element-wise addition of the shortcut, then the final activation.
        out += residual
        out = self.relu(out)

        return out
介绍完如何构建网络,接下来就是如何获取预训练模型。前面提到这一行代码:if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])),其中预训练模型是通过model_zoo.load_url函数获取的。
load_url函数源码如下。首先model_dir是下载下来的模型的保存地址,如果没有指定的话,就会保存在$TORCH_HOME/models目录下($TORCH_HOME默认为~/.torch,也可以通过环境变量$TORCH_MODEL_ZOO修改)。if not os.path.exists(cached_file)语句用来判断指定目录下是否已经存在要下载的模型文件:如果已经存在,就直接调用torch.load接口导入模型;如果不存在,则先从网上下载,下载是通过_download_url_to_file(url, cached_file, hash_prefix, progress=progress)进行的,不再细讲。重点在于模型导入是通过torch.load()接口来进行的,不管你的模型是从网上下载的还是本地已有的。
def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap
            storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar
            to stderr

    Returns:
        The object deserialized by ``torch.load`` from the cached file.

    Raises:
        ValueError: If the file must be downloaded but its name does not
            contain a ``-<sha256>.`` hash prefix.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url(
        ...     'https://download.pytorch.org/models/resnet18-5c106cde.pth')
    """
    # Imported locally because this file does not import these names at
    # module level.
    import os
    import re
    import sys

    import torch
    try:
        from urllib.parse import urlparse  # Python 3
    except ImportError:
        from urlparse import urlparse  # Python 2

    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        # NOTE(review): the body of this branch was missing in the original
        # text; creating the cache directory is the evident intent.
        os.makedirs(model_dir)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        # The hash prefix embedded in the file name is handed to the
        # downloader so the fetched content can be verified.
        # NOTE(review): the right-hand side of ``hash_prefix =`` was missing in
        # the original text; this pattern follows the naming convention
        # described in the docstring.
        match = re.search(r'-([a-f0-9]*)\.', filename)
        if match is None:
            raise ValueError(
                'cannot extract a hash prefix from file name {!r}; expected '
                'the form "filename-<sha256>.ext"'.format(filename))
        hash_prefix = match.group(1)
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # _download_url_to_file is expected to be provided by the module this
        # function lives in; it is not defined in the visible source.
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    # Whether freshly downloaded or already cached, loading always goes
    # through torch.load.
    return torch.load(cached_file, map_location=map_location)
