import torch
import torch.nn as nn
from collections.abc import Iterable, Iterator
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
def forward(self,x):
...
if __name__ == '__main__':
net = Net()
首先Net继承自nn.Module,通过super(python中的超类)完成父类的初始化。
在nn.Module()代码中,完成初始化方法通过python的魔法函数__setattr__完成。简单介绍下该魔法函数:setattr (object,name,value)用于设置当前对象(object)的属性(name)值(value)。当然,name属性不一定存在。简单有个概念即可,不理解魔法函数也没关系。
看下nn.Module中初始化部分的代码。
class Module:
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
# 初始化的属性
self.training = True
self._parameters = OrderedDict() # 存储模型参数,参与BP
self._buffers = OrderedDict() # 中间变量,不参与BP
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict() # 添加模块:比如conv/bn等
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
def remove_from(*dicts_or_sets):
for d in dicts_or_sets:
if name in d:
if isinstance(d, dict):
del d[name]
else:
d.discard(name)
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
self.register_parameter(name, value) #注册进self._parameters
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value) #注册进self._parameters
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
modules[name] = value # 注册进modules
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value # 注册进modules
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
buffers[name] = value # 注册进buffers
else:
object.__setattr__(self, name, value)
这里简单介绍下流程:以完成“self._parameters()属性”注册为例。当程序遇见self._parameters() 属性时,将自动执行def setattr(self,name,value)函数:此处self表示当前类本身。在初始化中,name = self._parameters(); value =OrderDict(); 之后在setattr中,会依次判断当前属性name是否是Parameter类/Module类/buffer。由于当前OrderDict都不是,因此,执行最底下代码:object.setattr(self,name,value)完成属性 self._parameters()的添加。
由于本代码中初始化均是OrderDict(),因此,实质上在初始化这个极简网络时,均是执行了 最后一行代码,即object.setattr(self,name,value)。
在上一节中主要介绍了通过super借助__setattr__完成了怎样的初始化。假如现在添加一个卷积层,初始化部分会有何不同呢?
import torch
import torch.nn as nn
from collections.abc import Iterable, Iterator
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.lele= nn.Conv2d(1, 1, 1, 1, 0)
def forward(self,x):
...
if __name__ == '__main__':
net = Net()
在第1小节中,实质上执行完了super(Net,self).init()语句。现在程序添加了一个卷积层self.lele=nn.Conv2d(1,1,1,1,0):当然,本节我们只关注Module的学习,并不关注卷积层是如何初始化的,即Conv2d类如何创建本文不讨论。本文默认已经实例了一个self.lele对象。现在核心是:如何将这个self.lele对象添加进Module初始化中呢?由于往类中添加了新的属性,因此,首先执行__setattr__函数。object = self , name=lele, value= nn.Conv2d类。
观察此时函数内部的运行情况。
这里简单介绍下nn.Conv2d类。在pytorch中,Conv2d类继承自_ConvNd类。而_ConvNd类继承自nn.Module(); _ConvNd主要完成卷积核的一些初始化参数工作。而具体到Conv2d/Conv1d类时,主要负责前向传播的运算任务。
Okay,言归正传。在执行到self.lele=nn.Conv2d(1,1,1,1,0)这句代码时:首先肯定完成卷积初始化工作,即完成_ConvNd中参数工作。以Conv2d中in_channel参数为例,首先也是通过__setattr__方法判断in_channel是否是Parameter/Module/Buffer,肯定均不属于。因此,同样也是执行object.setattr(self,name,value)这句代码。即将in_channel这种参数往卷积核类中添加了in_channel属性。
当然,_ConvNd中还有好多参数:同样采用上段中的方法完成了参数注册:
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t,
padding: _size_1_t,
dilation: _size_1_t,
transposed: bool,
output_padding: _size_1_t,
groups: int,
bias: Optional[Tensor],
padding_mode: str) -> None:
Conv2d中有weights(卷积核权重)和bias(卷积核偏置)两个Parameter类,即Conv2d中待学习的参数有两个weights和bias。在torch源码中长这样:
那么,如何将这两个Parameter类注册进self._parameters()?
同理:在__setattr__函数中:
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
...
if isinstance(value, Parameter): #判断若value是Parameter类
self.register_parameter(name, value) # 则注册进self._parameter
简单说就是判断下value的类型,若为Parameter,则注册进_parameters()。这里在看下register_parameter()函数代码:
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
r"""Adds a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args:
name (string): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter): parameter to be added to the module.
"""
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
elif not isinstance(name, torch._six.string_classes):
raise TypeError("parameter name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
raise KeyError("parameter name can't contain \".\"")
elif name == '':
raise KeyError("parameter name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._parameters:
raise KeyError("attribute '{}' already exists".format(name))
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError("cannot assign '{}' object to parameter '{}' "
"(torch.nn.Parameter or None required)"
.format(torch.typename(param), name))
elif param.grad_fn:
raise ValueError(
"Cannot assign non-leaf Tensor to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another Tensor, compute the value in "
"the forward() method.".format(name))
else:
self._parameters[name] = param
从上述代码可以看出:实质上就是判断Parameter是否符合各种规范,比如命名啥的。若全都符合,则执行最后一句代码,完成注册。
当然:理解完上述完全ok。还有一个细节涉及到python的又一个高级操作:就是标黄的部分:hasattr,python用来判断对象是否具有某个属性。在pytorch中重写了这个魔法方法:即hasattr实质上跳转到了__getattr__中。看下__getattr__(object_name,name)—用来获取对象的属性:
def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name]
if '_modules' in self.__dict__:
modules = self.__dict__['_modules']
if name in modules:
return modules[name]
raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
逻辑也比较简单:若具有这个属性,则通过字典进行返回:_parameters[name]。通过上述方式,就完成了Conv2d中weights和bias两个Parameter类的注册,同时也完成了Conv2d中所有参数的注册。
上述实质上仅仅完成了Conv2d中参数的初始化,现在,实质上实例化出了一个名为lele的卷积核对象,属于一个Module类。因此,还需要通过__setattr__将lele这个Module类注册进self._modules()。具体通过:
def add_module(self, name: str, module: Optional['Module']) -> None:
r"""Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (string): name of the child module. The child module can be
accessed from this module using the given name
module (Module): child module to be added to the module.
"""
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
torch.typename(module)))
elif not isinstance(name, torch._six.string_classes):
raise TypeError("module name should be a string. Got {}".format(
torch.typename(name)))
elif hasattr(self, name) and name not in self._modules:
raise KeyError("attribute '{}' already exists".format(name))
elif '.' in name:
raise KeyError("module name can't contain \".\"")
elif name == '':
raise KeyError("module name can't be empty string \"\"")
self._modules[name] = module
类似于注册register_parameter,此出不在赘述。
对于lele来说:因为lele从nn.Conv2d实例化的对象而来。所以,lele作为一个module(容器)包裹了待优化的参数(weights,bias)。即可以通过module._parameters查看待优化的参数。
上述其实已经完成了一个网络的构建,只要实现forward方法,就能运行了。但是在现有大多数pytorch参考书中:会告诉你诸如“调用.children()方法”。我就奇了怪了,咋调用的?因此,本节主要分析nn.Module中如何实现查看一个网络的module,parameter,buffer等。
截取nn.Module中查看module等属性的代码:还是简单网络为例。
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
for name, param in self.named_parameters(recurse=recurse):
yield param
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
gen = self._named_members(
lambda module: module._parameters.items(),
prefix=prefix, recurse=recurse)
for elem in gen:
yield elem
上述代码用来查看parameters。通过yield返回了一个生成器。生成器不理解请看另一篇博客。就是打印每个模块的参数。
同时,nn.Module还提供了查看其它参数的方法,代码类似。
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
for name, buf in self.named_buffers(recurse=recurse):
yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
gen = self._named_members(
lambda module: module._buffers.items(),
prefix=prefix, recurse=recurse)
for elem in gen:
yield elem
def children(self) -> Iterator['Module']:
for name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
def modules(self) -> Iterator['Module']:
for name, module in self.named_modules():
yield module
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix):
yield m
上述全部都是生成器。下面我给出一个如何使用这些方法的示例代码(写了我好久,太麻烦了。。。):
# -*- coding: utf-8 -*-
# ======================================================
# @Time : 2021/01/04
# @Author : lele wu
# @Email : [email protected]
# @File : lele_module.py
# @Comment: 研究nn.Module属性和方法
# ======================================================
import torch
import torch.nn as nn
from collections.abc import Iterable, Iterator
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.lele = nn.Conv2d(1, 1, 1, 1, 0)
#self.bn = nn.BatchNorm2d(1)
#self.relu = nn.ReLU(inplace=True)
def forward(self,x):
#out = self.conv(x)
#out = self.bn(out)
#out = self.relu(out)
#return out
...
if __name__ == '__main__':
# input = torch.Tensor([[[[-1,0],[1,2]]]])
net = Net()
# 查看网络的子模块
print('net.children返回一个迭代器吗?:', isinstance(net.children(), Iterator)) # 因为使用了yeild生成器表达式,故是生成器
print('***********************************************************************************************************')
print('通过调用children()方法:')
for module in net.children():
print('当前module类为:',module)
print('***********************************************************************************************************')
print('通过调用named_children()方法:')
for name,module in net.named_children():
print('类名:',name,'运算模块:',module) # name: 就是自己定义网络中 conv,bn,bn
print('***********************************************************************************************************')
print('通过调用modules()方法:')
for module in net.modules():
print('当前module类为:',module)
print('***********************************************************************************************************')
print('通过调用named_modules()方法:')
for module in net.named_modules():
print('当前module类为:',module)
print('***********************************************************************************************************')
print('总结: named_modules或者modules方法递归调用所有子类;而children和named_children方法则仅调用第一层子类:\n'
'参考资料: https://zhuanlan.zhihu.com/p/65105409?utm_source=wechat_session')
print('***********************************************************************************************************')
print('通过调用buffers()方法:')
for buffer in net.buffers():
print('当前buffer为:',buffer)
print('***********************************************************************************************************')
print('通过调用named_buffers()方法:')
for buffer in net.named_buffers():
print('当前buffer(元祖)为:',buffer)
print('***********************************************************************************************************')
此处我放张运行效果图得了:
当然,由于仅仅有一个卷积核。因此buffers为空,即卷积核在训练过程中不产生缓存。
这是第一篇,写的有点儿多。后续会介绍hook,apply等。当然,下节会首先介绍如何冻结参数。