python出现no module named cv2_pytorch源码阅读(一)Module类

python出现no module named cv2_pytorch源码阅读(一)Module类_第1张图片

pytorch虽然简单易用,但是其高度的封装使得自己在使用中经常出现各种疑惑。本文针对pytorch的核心基类nn.Module【1】进行分析,Module作为各种操作的父类是每个网络定义必须继承的基类。首先来看Module类的类注释:

class 

简单的说,Module类是所有神经网络模块的基类,Module可以以树形式包含别的Module,也就是网络定义中经常使用的子网络嵌套。Module的源码有上千行,包含众多的函数,下面分别分析。

1、初始化函数__init__()

Module类的初始化函数非常简单:

def 

最突出的是OrderedDict的大量使用,OrderedDict是有序的字典,也就是键-值对的插入是按照次序进行的。self._parameters用于存储网络的参数;self._buffers用于存储不需要优化器进行更新的变量;self._backward_hooks和self._forward_hooks分别是前向和反向的钩子,用于获取网络中间层输入输出;self._state_dick_xxx表示状态字典,用于存储和参数值加载等。

2、前向传播forward()

def 

简单的说,forward()函数用于前向传播的定义并且必须实现,否则就会抛出错误。

3、缓存注册register_buffer()

def 

函数注释上面说buffer表示那些不是parameter但是需要存储的变量,比如BN中的mean。不是parameter表示这个量不需要被optimizer更新,也就是不需要训练。函数中大量的if-else主要用于判别输入tensor和name的类型,从而确定是否进行注册。可以在__init__()中使用self.register_buffer()并使用self.named_buffers()查看已经注册的buffer:

class 

可以得到自己注册的buffer:

(

4、注册参数register_parameter()

def 

注册parameter和注册buffer不同,parameter是需要进行训练更新的,注册的parameter和网络定义的卷积和全连接等的weight性质相同,可以通过self.named_parameters()查看,下面定义的网络包含一个Linear层和一个注册的parameter:

class 

将会得到Linear的weight和bias,以及注册的parameter:

(

5、添加新的操作add_module()

def 

这个函数用于给模型添加新的操作,下面采用它给网络添加新的操作,需要注意的是Module采用的是OrderedDict,操作的添加必须是按照顺序的:

class 

得到两个Linear操作:

Net

6、递归施加函数apply()

def 

这个函数是一个递归过程,对当前module的所有第一代子module施加fn函数,这个fn函数是自定义但是并非随意定义,fn中的操作必须是module类所能满足的。源码中给出的是参数初始化的例子,首先需要定义一个用于参数初始化的fn,然后使用module.apply(fn),下面代码中Net类上一节定义的Net():

def 

现在,net的linear的weight就是1000了。

7、cpu和gpu间切换cuda()、cpu()以及to()

def 

这两个函数的功能很明确,就是将数据在cpu和gpu之间切换,类似的还有数据类型的切换float()和double(),由于涉及到底层代码暂不讨论。印象中pytorch的早期版本并没有to(),其作用是根据设备信息动态确定数据位置:

def 

to()中定义了内部函数convert(),特意检查了数据是否为4d形式。

8、模型状态state_dict()与load_state_dict()

def 

模型的state_dict()主要用于参数保存和重载,从state_dict()的源码可以看出它主要是一个递归过程,不断的进行子module的搜索。load_state_dict()源码比较长,在此不列出了。

9、模型参数parameters()与named_parameters()

def 

源码中的“yield”表明它返回的是一个生成器generator,这个generator可以通过迭代的方式获取所有元素,前面的例子已经测试过此函数了,它调用的name_parameters()函数如下:

def 

named_parameters()函数的参数recurse用于确定是否递归,即是否遍历子module的子module,它同样是一个生成器函数。相同作用的函数还有buffers()和named_buffers(),用于返回buffer的两个generator;以及children()和named_children(),用于返回子module;还有modules()和named_modules(),用于递归返回所有子代module。

10、切换训练和测试train()和eval()

def 

训练模式和验证模式针对某些操作是不同的,比如“Dropout”和“BN”等,所以网络需要切换训练和测试模式,train()函数依旧是一个遍历过程,对每个子代module都进行设置。eval()模式仅仅需要将train()函数的mode设置为False即可:

def 

11、清空操作zero_grad()

def 

zeros_grad的目的是针对所有的parameters进行统一的梯度清零操作,依旧是一个遍历过程。

参考:

【1】https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

【2】https://pytorch.org/docs/stable/nn.html#

你可能感兴趣的:(python出现no,module,named,cv2)