PyTorch Learning 8: Model Creation Steps and nn.Module Attributes

Contents

      • Introduction
      • I. Creating a Model
        • 1. nn.Module

Introduction

  In this section, we start covering the model part of the PyTorch workflow.

I. Creating a Model

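  Building a model involves two elements:

  • Building the submodules: the layers are created in __init__()
  • Assembling the submodules: the layers are connected in forward()

Below, we use the LeNet model as an example to walk through the model-creation process.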

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    # Build the submodules during initialization
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 3 in-channels, 6 out-channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)    # 16*5*5 matches a 32x32 input
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    # Assemble the submodules (forward pass)
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)      # flatten to (batch, 16*5*5)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    # Weight initialization
    def initialize_weights(self):
        # self.modules() walks this module and every submodule recursively
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
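Why does fc1 expect 16*5*5 input features? The layer sizes only line up for one input size. A quick shape trace, assuming 3-channel 32x32 inputs (the text above does not state the input size; 32x32 is the size for which 16*5*5 works out), makes the arithmetic visible:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Same layer sizes as in LeNet above; the 32x32 input size is an assumption
conv1 = nn.Conv2d(3, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
fc1 = nn.Linear(16 * 5 * 5, 120)

x = torch.randn(1, 3, 32, 32)               # one 3-channel 32x32 image
out = F.max_pool2d(F.relu(conv1(x)), 2)     # 32 -> 28 (5x5 conv) -> 14 (2x2 pool)
print(out.shape)                            # torch.Size([1, 6, 14, 14])
out = F.max_pool2d(F.relu(conv2(out)), 2)   # 14 -> 10 (5x5 conv) -> 5 (2x2 pool)
print(out.shape)                            # torch.Size([1, 16, 5, 5])
flat = out.view(out.size(0), -1)            # flatten: 16 * 5 * 5 = 400 features
print(fc1(flat).shape)                      # torch.Size([1, 120])
```

With a different input size, say 28x28, the flattened tensor would have 16*4*4 = 256 features instead of 400 and fc1 would raise a shape-mismatch error.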

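So when does the model actually assemble its submodules and run the forward pass?
LeNet inherits from nn.Module, and Module defines a __call__ method. __call__ makes an instance callable like a function (model(x)), and inside __call__ the forward() method defined above is invoked. The listing below is Module.__call__ from an older PyTorch 1.x release; newer releases restructure this code, but they follow the same basic sequence (forward pre-hooks, forward(), forward hooks, backward-hook registration).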

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            # forward pass
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result
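To see that order in practice, here is a small sketch with a hypothetical toy module, Doubler, plus one forward pre-hook and one forward hook registered through the public register_forward_pre_hook / register_forward_hook methods. It records what runs when the instance is called, and what runs when forward() is called directly:

```python
import torch
import torch.nn as nn

calls = []  # records the order in which things run


class Doubler(nn.Module):
    # Hypothetical toy module, used only to illustrate __call__
    def forward(self, x):
        calls.append("forward")
        return x * 2


m = Doubler()
# Both hooks return None, so they observe but do not change input/output
m.register_forward_pre_hook(lambda module, inp: calls.append("pre_hook"))
m.register_forward_hook(lambda module, inp, out: calls.append("forward_hook"))

x = torch.ones(3)
y = m(x)        # Module.__call__: pre-hooks -> forward() -> forward hooks
print(y)        # tensor([2., 2., 2.])
print(calls)    # ['pre_hook', 'forward', 'forward_hook']

z = m.forward(x)  # calling forward() directly skips the hooks
print(calls)      # ['pre_hook', 'forward', 'forward_hook', 'forward']
```

This is why the usual practice is to write model(x) rather than model.forward(x): only the former goes through __call__ and therefore through the hooks.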

1. nn.Module

  A very important concept in the model part is nn.Module: every model and every network layer inherits from the nn.Module class. Let us first take a quick look at torch.nn.
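The torch.nn package groups several building blocks. The following is a small sketch touching four of them (nn.Parameter, nn.init, nn.functional and nn.Module); the tensor shapes are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# nn.Parameter: a Tensor subclass that nn.Module registers as a learnable parameter
w = nn.Parameter(torch.empty(4, 3))
print(isinstance(w, torch.Tensor), w.requires_grad)  # True True

# nn.init: initialization functions; the trailing underscore means "in place"
nn.init.xavier_normal_(w)

# nn.functional: stateless operations (activations, pooling, convolution, ...)
x = torch.randn(2, 3)
y = F.relu(F.linear(x, w))  # same math as a bias-free nn.Linear(3, 4), then ReLU
print(y.shape)              # torch.Size([2, 4])

# nn.Module: the base class for layers and models
print(issubclass(nn.Linear, nn.Module))  # True
```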
The focus of this section is nn.Module. Its attributes are as follows:

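  • parameters: stores and manages nn.Parameter objects,
    e.g. weights and biases
  • modules: stores and manages nn.Module objects,
    e.g. the convolutional and fully connected layers in LeNet
  • buffers: stores and manages buffer attributes, such as running_mean in a BN layer
  • *_hooks: store and manage hook functions

In the version shown here, Module.__init__() creates these as eight OrderedDicts:

    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()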
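To see these dictionaries filled in, a minimal sketch with a hypothetical module Tiny that has one parameter, one submodule and one buffer. The underscore dicts are internal implementation details, inspected here only for illustration; everyday code would use public accessors such as named_parameters() and named_buffers():

```python
import torch
import torch.nn as nn


class Tiny(nn.Module):
    # Hypothetical module with one attribute of each kind
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))              # -> _parameters
        self.fc = nn.Linear(3, 2)                             # -> _modules
        self.register_buffer("running_mean", torch.zeros(3))  # -> _buffers

    def forward(self, x):
        return self.fc(x - self.running_mean) * self.scale


m = Tiny()
print(list(m._parameters))  # ['scale']
print(list(m._modules))     # ['fc']
print(list(m._buffers))     # ['running_mean']

# The public accessors walk these dicts recursively, submodules included
print([n for n, _ in m.named_parameters()])  # ['scale', 'fc.weight', 'fc.bias']

# BatchNorm2d keeps its running statistics as buffers, not as parameters
bn = nn.BatchNorm2d(6)
print(list(bn._buffers))  # ['running_mean', 'running_var', 'num_batches_tracked']
```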
    

Note: Module intercepts every attribute assignment on an instance (statements such as self.conv1 = ...) and routes it to Module.__setattr__():

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

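The main job of this function is to check the type of value:

  • If it is a Parameter, it is registered via register_parameter(), which stores it in the _parameters dict
  • If it is a Module, it is stored in the _modules dict
  • Otherwise, unless the name is already a registered buffer, it is set as an ordinary attribute via object.__setattr__()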
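The routing can be checked directly. A minimal sketch with two hypothetical classes: Demo assigns a Parameter, a Module and a plain tensor, and Broken assigns a Parameter before calling super().__init__(), which trips the first check in __setattr__:

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    # Hypothetical module: every self.xxx = ... below goes through Module.__setattr__
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(2))  # Parameter -> registered in _parameters
        self.layer = nn.Linear(2, 2)           # Module -> stored in _modules
        self.t = torch.randn(2)                # plain Tensor -> ordinary attribute


d = Demo()
print("w" in d._parameters, "layer" in d._modules)  # True True
print("t" in d.__dict__, "t" in d._parameters)      # True False
# Only registered parameters are visible to d.parameters() (and so to an optimizer)
print(len(list(d.parameters())))                    # 3: w, layer.weight, layer.bias


class Broken(nn.Module):
    def __init__(self):
        # _parameters does not exist yet, so __setattr__ raises AttributeError
        self.w = nn.Parameter(torch.randn(2))
        super().__init__()


try:
    Broken()
except AttributeError as e:
    err = str(e)
print(err)  # cannot assign parameters before Module.__init__() call
```

This is also why super().__init__() must be the first statement in a custom module's __init__().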

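nn.Module summary:

  • A module can contain multiple submodules
  • A module represents an operation and must implement forward()
  • Each module manages its attributes with 8 dictionaries (OrderedDicts)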
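The first point, that a module can contain submodules, is what lets self.modules() in LeNet's initialize_weights() reach every layer: it walks the whole module tree. A minimal sketch with two hypothetical classes, Block and Net:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    # Hypothetical submodule that itself contains two submodules
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


class Net(nn.Module):
    # Hypothetical top-level model: a Module whose children are Modules
    def __init__(self):
        super().__init__()
        self.block = Block()
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        x = self.block(x).mean(dim=(2, 3))  # global average pooling -> (N, 8)
        return self.head(x)


net = Net()
print([n for n, _ in net.named_children()])  # ['block', 'head']  (direct children only)
print([n for n, _ in net.named_modules()])   # ['', 'block', 'block.conv', 'block.bn', 'head']
print(net(torch.randn(2, 3, 4, 4)).shape)    # torch.Size([2, 2])
```

named_children() stops at the first level, while named_modules() (and modules()) include the module itself, under the empty name '', followed by every nested submodule.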
