PyTorch 源码分析:Module类

目录

1.前言

2. __setattr__魔法方法

3. 遍历Module

4.遍历parameter

5. 遍历Module并进行某种操作(apply)

6. 遍历parameter并进行某种操作(_apply)

总结


1.前言

虽然说是源码分析,但其实主要还是分析Module的的数据结构,即一些比较重要的成员变量和成员函数。同时也不会涉及Pytorch的C++部分。

Module是Pytorch中一个比较重要的类,我们平常在用Pytorch搭模型的时候就需要继承Module类,也经常会用到Module类的一些方法,比如to,cuda,parameters等。所以我希望能通过对Module类进行数据结构分析,更了解这个类。

在我的理解中,Module类完成了对计算函数的一种封装。那么对于Module而言,比较重要的部分就是(1)计算函数自带的参数(parameters,buffer)、(2)真正的计算函数。

另一方面Module可以嵌套地封装Module,所以内部还需要以Module为单位进行管理。

后面我们可以看到整个Module类的数据结构和功能也是围绕这两个部分而展开的。

2. __setattr__魔法方法

__setattr__这个方法会在成员变量被赋值时调用,把成员变量名和值按照key:value的形式存在self.__dict__这个成员变量中,其中成员变量名为string类型。

Module类中重写了这个方法,主要是更改了存储逻辑便于对Module和Prameters进行统一管理,也增加了一些边界条件的判断。

先从两个例子直观理解一下__setattr__的作用。

(1)Module中包含Module

import torch
from torch import nn
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
model = ConvNet()
print(model.__dict__)
"""
输出结果:
{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_modules': OrderedDict([('layer1',
               Sequential(
                 (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
                 (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                 (2): ReLU()
                 (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
               )),
              ('layer2',
               Sequential(
                 (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
                 (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                 (2): ReLU()
                 (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
               )),
              ('fc', Linear(in_features=1568, out_features=10, bias=True))])}
"""

这里可以看到所有继承Module的类(Sequential,Linear等)都以key:value的形式被存储到了_modules这个成员变量字典中。

(2)Module中包含Parameters

fc = nn.Linear(10, 10)
print(fc.__dict__)
"""
输出结果:
{'training': True,
 '_parameters': OrderedDict([('weight', Parameter containing:
               tensor([[ 0.2134, -0.1890,  0.1323, -0.0584, -0.0152,  0.2372,  0.2294, -0.0924,
                        -0.1262,  0.2316],
                       [-0.1862,  0.1274, -0.0426,  0.1631,  0.0892, -0.2445, -0.0770, -0.0417,
                        -0.2140, -0.1274],
                       [-0.2511,  0.2695, -0.2415, -0.0730, -0.1756, -0.2988, -0.2573,  0.0118,
                        -0.1029,  0.1782],
                       [ 0.0450,  0.2580, -0.2233, -0.1633,  0.2082,  0.2528,  0.1202, -0.2490,
                        -0.0706, -0.1986],
                       [ 0.1465,  0.1913, -0.0517,  0.1790, -0.2581, -0.1569, -0.0745,  0.0543,
                        -0.2791, -0.2803],
                       [ 0.1897,  0.1328, -0.1340, -0.2134, -0.0721,  0.1395,  0.1386,  0.0060,
                        -0.2429, -0.2224],
                       [-0.2489, -0.1251, -0.0911,  0.2740,  0.1966, -0.0599,  0.2756, -0.2151,
                        -0.2663,  0.2152],
                       [-0.0975, -0.2635,  0.2904, -0.1424,  0.1922,  0.0336,  0.1513, -0.2895,
                        -0.2024,  0.0007],
                       [-0.1030, -0.2233, -0.1168,  0.0167,  0.0828,  0.1948,  0.1541,  0.2302,
                         0.2033, -0.3030],
                       [-0.1389,  0.2790,  0.1057, -0.2378,  0.2950, -0.1728, -0.1850,  0.2629,
                         0.1562, -0.1782]], requires_grad=True)),
              ('bias',
               Parameter containing:
               tensor([-0.0800, -0.2401,  0.0667, -0.1441, -0.0163, -0.1638, -0.2017, -0.0315,
                       -0.2545,  0.1657], requires_grad=True))]),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_modules': OrderedDict(),
 'in_features': 10,
 'out_features': 10}
"""

和Module类似,所有继承了Parameter类的参数都以key:value的形式被存到了_parameters这个成员变量中。

这里可以明显地感受到,_parameters和_modules是两个比较重要的成员变量,接下来观察Module都提供了哪些方法用于操作_parameters和_modules。

3. 遍历Module

Module类中提供了4种遍历Module的方法,children,named_children,modules,named_modules。

children和named_children的功能是遍历当前Module的子Module。

modules和named_modules则是遍历所有Module。

其中children和modules都只是重复调用了named_children和named_modules。

值得注意的是这4个方法都有yeild关键字,所以它们其实返回的是一个生成器对象。我之前也有对生成器的介绍,对生成器不是很理解的可以先看一下那个部分。

这里以一个例子展示named_children和named_modules的区别。

named_children:

import torch
from torch import nn
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
model = ConvNet()
for name,submodule in model.named_children():
    print("module name:",name)
    print("module:",submodule)
"""
输出结果:
module name: layer1
module: Sequential(
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
module name: layer2
module: Sequential(
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
module name: fc
module: Linear(in_features=1568, out_features=10, bias=True)
"""

named_modules:

for name,submodule in model.named_modules():
    print("module name:",name)
    print("module:",submodule)
"""
输出结果:
module name: 
module: ConvNet(
  (layer1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Linear(in_features=1568, out_features=10, bias=True)
)
module name: layer1
module: Sequential(
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
module name: layer1.0
module: Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
module name: layer1.1
module: BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
module name: layer1.2
module: ReLU()
module name: layer1.3
module: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
module name: layer2
module: Sequential(
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
module name: layer2.0
module: Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
module name: layer2.1
module: BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
module name: layer2.2
module: ReLU()
module name: layer2.3
module: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
module name: fc
module: Linear(in_features=1568, out_features=10, bias=True)
"""

这里比较有意思的是named_modules,可以看到它的遍历顺序应该是深搜,类似于二叉树的先序遍历,也就是先访问根结点,再递归地按顺序访问所有的子结点。

可以看一下它的源码实现:

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
    if memo is None:
        memo = set()
    if self not in memo:
        if remove_duplicate:
            memo.add(self)
        #访问根节点
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            #递归地调用子结点的named_modules方法
            for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
                yield m

其实也没多少行。首先是访问根结点,之后递归地调用子结点的named_modules。

这里memo的作用是避免重复访问,所有未访问过的结点都会在访问前被放入memo。

prefix就是给name加前缀,所以在没有前缀的情况下,我们可以看到最上面的根节点是没有名字的。

4.遍历parameter

遍历parameter也有两个方法named_parameters和parameters。

这里也只重点介绍named_parameters方法。

named_parameters源码:

def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

这里调用了另一个方法_named_members,这个_named_members的输入是一个lambda匿名函数,输出一个生成器对象。通过遍历这个生成器,可以实现对所有parameter的访问。

我们再看一下这个输入的匿名函数的功能,它的输入是一个module对象,输出这个module对象的

_parameters成员变量的items形式(字典列表化)。

named_members源码:

def _named_members(self, get_members_fn, prefix='', recurse=True):
    r"""Helper method for yielding various names + members of modules."""
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v

这个方法的实现和named_modules是差不多的,都是对module进行深搜,但是这里我们只是访问module,还需要访问每个module相应的parameters。所以这个匿名函数的意思就是让你自己定义需要对结点做什么操作,这种做法是比较常见的。

5. 遍历Module并进行某种操作(apply)

module的apply方法和named_members也蛮像的,都是对_modules进行遍历,然后进行你想要的操作,不过apply更通用一些。

apply源码:

def apply(self: T, fn: Callable[['Module'], None]) -> T:
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

主要的区别就是原先先访问根结点,再递归地访问子结点。现在是先递归地访问子结点,最后再访问跟结点。另一点就是不再判断是否有重复访问。

6. 遍历parameter并进行某种操作(_apply)

其实真要遍历parameter的话,用apply也是可以实现的,不过这里还是重新写了一个_apply。

_apply对参数和参数的梯度都用fn处理了一遍。

处理完还需要用compute_should_use_set_data判断一下,是否需要替换原先的参数。

这里我举个例子,model.cuda()将参数从CPU迁移到GPU,那么原先CPU处的内存自然就可以释放掉了。

def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    for key, param in self._parameters.items():
        #操作parameter
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            #对parameter本身执行fn
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)
            #对parameter的梯度执行fn
            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
    
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self

总结

好吧,虽然说是源码分析,但是确实没涉及到很多源码的部分,但是经过上面的分析,其实也可以发现Module类各种方法的实现在Python层面也没有很复杂,主要还是对子module和parameters做各种操作。在理清了主次之后其实是比较容易能看懂的。

后续的话准备再看Parameter类,还有它的父类Tensor,主要还是Tensor吧。最后再看一下相关的Dataset类。这样对Pytorch的认识就会比较清晰了。

你可能感兴趣的:(pytorch,数据结构,深度学习)