PyTorch搭建VGG网络

1. VGG Net 网络结构

PyTorch搭建VGG网络_第1张图片
PyTorch搭建VGG网络_第2张图片

2. 搭建过程(以VGG19为例)

(1)加载必要及准备工作

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
from torch.nn import functional as F

#对外暴露接口,当导入该模块时,只导入枚举列表里面的
__all__ = ['vgg19']
#预训练模型下载地址
model_urls = {
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}

# 可以定义不同的vgg结构,这样写可以有效节约代码空间,下面以vgg19为例
cfg = {
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
}

基于PyTorch的一些预训练模型下载地址如下:https://github.com/pytorch/vision/tree/master/torchvision/models

(2)构建模型

class VGG(nn.Module):
#nn.Module是一个特殊的nn模块,加载nn.Module,这是为了继承父类
    def __init__(self, features):
        super(VGG, self).__init__()
        #super 加载父类中的__init__()函数
        self.features = features
        self.reg_layer = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = F.upsample_bilinear(x, scale_factor=2)
        x = self.reg_layer(x)
        return torch.abs(x)

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

torch.nn.Module是所有网络的基类,Modules也可以包含其它Modules,允许使用树结构嵌入他们。可以将子模块赋值给模型属性;

torch.nn.Sequential(*args)是一个时序容器,会以他们传入的顺序被添加到容器中;

forward(*input)定义了每次执行的计算步骤。在所有的子类中都需要重写这个函数;

class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False) 对于输入信号的输入通道,提供2维最大池化(max pooling)操作;

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)二维卷积层;

class torch.nn.ReLU(inplace=False) 对输入运用修正线性单元函数 R e L U ( x ) = m a x ( 0 , x ) {ReLU}(x)= max(0, x) ReLU(x)=max(0,x),inplace-选择是否进行覆盖运算。

(3)检验模型

def vgg19():
    """VGG 19-layer model (configuration "E")
        model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E']))
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']), strict=False)
    return model

torch.utils.model_zoo.load_url(url, model_dir=None) 在给定URL上加载Torch序列化对象。如果对象已经存在于 model_dir 中,则将被反序列化并返回。
url (string) - 要下载对象的URL;
model_dir (string, optional) - 保存对象的目录。

load_state_dict(state_dict) 将state_dict中的parameters和buffers复制到此module和它的后代中。state_dict中的key必须和 model.state_dict()返回的key一致。 NOTE:用来加载模型参数。
state_dict (dict) – 保存parameters和persistent buffers的字典。

你可能感兴趣的:(pytorch学习)