Pytorch学习(2) —— 网络工具箱 TORCH.NN 基本类用法

结合博客《Pytorch学习(1) —— Tensor基础》我们已经了解了Tensor的一些基础知识,下面开始结合这些介绍TORCH.NN里面的一些内容,这些内容包含了构建网络的一些基本类(Module, Prameters等)操作,并对一些关键网络函数(卷积,池化等)进行整理,作为字典方便后续查询。

正常来说,Tensor说完应该介绍torch.autograd自动求导问题,但是在具体应用中,似乎并没有太多的使用,因此这里跳过这个部分(自动求导的简要说明可以查看其它博主的分析),后续遇到了再去了解。PyTorch进一步提供了集成度更高的模块化接口torch.nn, 该接口构建于Autograd之上, 提供了网络模组、 优化器和初始化策略等一系列功能。

网络定义几乎都是通过一个类来实现,在类中才明确用哪种层,因此,这一节给出torch.nn中一些类的用法。

文章目录

  • 1 nn.Parameters 类
  • 2 nn.Module 类
    • 2.1 add_module(name, module)
    • 2.2 apply(fn)
    • 2.3 children() 和 modules()
    • 2.4 train() 和 eval() 问题
    • 2.5 load_state_dict 和 state_dict
  • 3 nn.Sequential 类
  • 4 nn.ModuleList 类
  • 5 nn.ModuleDict 类
  • 6 nn.ParameterList 类
  • 7 nn.ParameterDict 类
  • 总结

1 nn.Parameters 类

Parameters是一种被认为是模型参数的Tensor,这个存的是这个模型需要进行训练的参数。通过查看源代码class Parameter(torch.Tensor),可以发现这个类就是继承于Tensor类,就是一个Tensor的扩展。这个类往往与Module模型类一起使用,当它们被指定为Module属性时,它们会自动添加到其参数列表中。

使用时候仅需要输入两个参数,data:参数Tensor,和requires_grad:是否需要求导,默认为True。

2 nn.Module 类

nn.Module所有神经网络模块的基类,并在类中实现了网络各层的定义及前向计算与反向传播机制。 在实际使用时, 如果想要实现某个神经网络, 需要继承nn.Module。下面是官方给出定义一个模型的例子,forward一定要写出,这个指定了网络的计算方法,下面这种模块里面内部有参数,所以在构造模型时候,参数也被注册进参数列表中。

import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

我们输入model = Model()并使用print输出,可以得到网络模型

Model(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))
)

下面给出Module类里面的各种成员函数的使用方法。

2.1 add_module(name, module)

将子模块添加到当前模块,可以使用给定的名称作为属性访问模块。下面给出一种示例,就是把上面的示例self.conv1和self.conv2的定义方法换一下,发现就是名换了,此外,也可以用name作为变量名,self.name。

self.conv1 = self.add_module('conv_01', nn.Conv2d(1, 20, 5))
self.conv2 = self.add_module('conv_02', nn.Conv2d(20, 20, 5))

输出:
Model(
  (conv_01): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv_02): Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))
)

2.2 apply(fn)

典型用途包括初始化模型的参数,下面给出一种初始化方法。这个函数我觉得在后期网络迁移时候应该常用,比如自己删减掉某些层,使用这个方法可以手动对某些层进行参数初始化,初始化为需要的。

def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

2.3 children() 和 modules()

这个两个的区别我通过查看博客《self.modules() 和 self.children()的区别》才了解。

以下面这个图为例,self.modules()采用深度优先遍历的方式,存储了net的所有模块,包括net itself,net’s children, children of net’s children, 而self.children()存储网络结构的子层模块,也就是net’s children那一层,如果想通过代码深入了解,查看那篇博客即可,写的很清晰啦。

Pytorch学习(2) —— 网络工具箱 TORCH.NN 基本类用法_第1张图片

2.4 train() 和 eval() 问题

此部分的理解我参考了博客《model.train和model.eval用法和区别》

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

简而言之,训练模式下,设置train(True),测试模式下,设置eval(True)

2.5 load_state_dict 和 state_dict

load_state_dict(state_dict, strict=True)是将参数state_dict拷贝到模型中,如果strict为True,则state_dict中的键值一定要与模型中的参数键值匹配,简而言之就是按照参数名进行参数拷贝,返回值是拷贝一些没用的参数或异常参数,如果存在不匹配的则会报错。strict为False时候则忽略这些问题。

state_dict(destination=None, prefix='', keep_vars=False)其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数。官方没给参数啥含义,我一会看源码分析下

下面给出较为直观的几个基本函数。

名称 用法
cpu() 将模型所有参数与缓存移至CPU中
cuda(device=None) 将模型所有参数与缓存移至GPU中,通过device设置GPU类型
double() 将模型所有浮点的参数与缓存转换为double类型
float() 将模型所有浮点的参数与缓存转换为float类型
half() 将模型所有浮点的参数与缓存转换为半精度half类型
type(dst_type) 将所有参数和缓冲区强制转换为dst_type。
zero_grad() 将模型所有参数梯度设置为0
buffers(recurse=True) 返回模型的缓存迭代器
requires_grad_(requires_grad=True) 设置是否需要求导,在实际中有助于冻结某些层不让其进行反向传导,常见于泛化与迁移问题
forward(*input) 定义网络都需要前向传播,这个由用户自己开发,后期针对某些网络就知道这些怎么用了
parameters(recurse=True) 返回模型参数迭代器

下面是目前没有成功研究明白的几个函数,这些函数不常用,防止拖慢进度,先放着,有空研究

  • buffers(recurse=True)
  • dump_patches = FALSE
  • extra_repr()
  • named_buffers(prefix=’’, recurse=True)
  • named_children()
  • named_modules(memo=None, prefix=’’)
  • named_parameters(prefix=’’, recurse=True)
  • register_backward_hook(hook)
  • register_buffer(name, tensor)
  • register_forward_hook(hook)
  • register_forward_pre_hook(hook)
  • register_parameter(name, param)
  • to(*args, **kwargs)

3 nn.Sequential 类

Sequential类继承于Module类,是一个连续的容器,用于快速构建一个模型,返回也是个模型。模块将按在构造函数中传递的顺序添加到其中。为了更容易理解,下面是一个小例子:

# Sequential的一个例子
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# 带有OrderedDict的Sequential例子,注意使用from collections import OrderedDict
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

对上述这两个模型输出,可以发现,第二个方法可以直接对网络结构进行命名。

Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)

4 nn.ModuleList 类

将网络中子模块合成一个list,可以像Python的list一样索引,下面是一个例子

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

下面是这个列表开放的几个函数。

名称 用途
append(module) 在尾部添加一个模型
extend(modules) 在尾部添加多个模型
insert(index, module) 在index位置插入一个模型

5 nn.ModuleDict 类

模型字典,这里面存了一大堆模型信息,与ModuleList相似,这个只是个容器,不涉及前向执行问题,下面给出官方的例子,很显然这个可以看成是一个模型的一个字典容器,在构建模型前向时候,可以根据需要选择不同的网络。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
     
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

下面给出这个类的几个开放函数。

名称 用法
clear() 清空这个模型字典
items() 返回一个键值对,用法与Python的字典用法相同
keys() 返回一个键值迭代器,用法类似for key in P.keys()
pop(key) 删除key键,并返回对应的模型
update(parameters) 更新字典,输入新的键值对
values() 返回字典中值(也就是模型)的迭代器

6 nn.ParameterList 类

这个类实际上是将一个Parameter的List转为ParameterList,如下例所示[nn.Parameter(torch.randn(10, 10)) for i in range(10)]类型是List,List的每个元素是Parameter,然后这个List作为参数传入这个类构造ParameterList类型。

ParameterList输入一定是一个Parameter的List,其他类型会报错,在注册时候就会提示元素不是Parameter类型。

parms = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

下面是官方给的一个示例,就是一个示例,对于需要大规模参数时,一个个手打上去非常复杂,且阅读难,因此使用列表可以快速构建所需参数。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

下面给出这个类的开放函数。

名称 用法
append(parameter) 在列表后面添加一个Parameter类
extend(parameters) 在列表后面添加一个列表

同样的,为了方便用户使用,元素提取赋值访问等于正常的List用法一样,具体用法可以看下一节源码分析就知道了,这里就不细说了,太多了。

7 nn.ParameterDict 类

ParameterDict 是一个字典类源码,与python的字典非常相似,下面就是字典的一个例子,输入参数是个普通字典,然后转换为ParameterDict类型。

params = nn.ParameterDict({
      'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10))})

官方给了这个字典的一个用法,其实就是在运算过程中,根据输入的forward参数,选择不同的参数,目前不需要具体了解这个在网络中有何应用,到时候做项目自然会碰到,需要时候再查用法即可。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterDict({
     
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x

下面给出这个类的开放函数。

名称 用法
clear() 清空这个字典
items() 返回一个键值对,用法与Python的字典用法相同
keys() 返回一个键值迭代器,用法类似for key in P.keys()
pop(key) 删除key键,并返回对应的值
update(parameters) 更新字典,输入新的键值对
values() 返回字典值迭代器

ParameterDict 主要就是这些,简单来说,这个就是用于pytorch计算的一个字典类,官方已经将其封装的与python自带的字典类用法差不多了。

总结

总体来说这张补充的是网络一些类的基本用法,但是仍然存在一些无法确定理解的函数,这些需要后续进行细致的分析与研究,后面会不断更新新函数,但是这个部分存在的问题我也会继续深入研究。

你可能感兴趣的:(Pytorch学习,python,深度学习)