Pytorch容器:nn.Sequential,ModuleList,ParameterList源码解析

nn.Sequential

源码位置:./torch/nn/modules/container.py

# 查看torch位置

>>> import torch
>>> torch.__file__

# container.py包含内容 
# 都继承自nn.Module类

# class Sequential(Module)

# class ModuleList(Module)
# class ModuleDict(Module)

# class ParameterList(Module)
# class ParameterDict(Module)

1 nn.Sequential用法及源码

函数包括:

Pytorch容器:nn.Sequential,ModuleList,ParameterList源码解析_第1张图片

1.1 初始化方法__init__

__init__ 用法:

# 用法1 传入OrderedDict
model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
                
# 用法2 直接添加module
model = nn.Sequential(
    # 2-D convolution: 1 input channel, 20 output channels, 5x5 kernel
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
)

源码:

# 有序 容器
class Sequential(Module):
    """Container that registers its submodules in the order they are given.

    The constructor accepts either any number of modules as positional
    arguments, or exactly one ``OrderedDict`` mapping names to modules.
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args):
        """Register every given module via ``add_module``.

        Args:
            *args: either a single ``OrderedDict`` (its keys become the
                registered names) or any number of modules (registered
                under the string keys '0', '1', '2', ...).
        """
        super().__init__()
        single_mapping = len(args) == 1 and isinstance(args[0], OrderedDict)
        if single_mapping:
            # Names come from the OrderedDict keys.
            named_layers = args[0].items()
        else:
            # Names are the stringified positions of the arguments.
            named_layers = ((str(position), layer) for position, layer in enumerate(args))
        for name, layer in named_layers:
            self.add_module(name, layer)

1.2 __getitem__,__setitem__,__delitem__

__getitem__
__setitem__
__delitem__

用法:

# getitem
# 1.如果是用OrderedDict创建的 用key获取
>>> model
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
>>> model.conv1
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
>>> model.relu1
ReLU()
>>> 

# 2.如果是直接创建的 用index获取
>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> model[0]
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
>>> model[1]
ReLU()
>>> 

源码:

# 调用了_get_item_by_idx函数
def __getitem__(self, idx) -> Union['Sequential', T]:
    """Return the submodule at ``idx``, or a new container for a slice.

    A slice is applied to the ordered (name, module) pairs and the selected
    pairs are wrapped in a new instance of ``self.__class__``. Any other
    index is resolved through ``_get_item_by_idx``.
    """
    if not isinstance(idx, slice):
        # Single position: delegate the lookup to the helper.
        return self._get_item_by_idx(self._modules.values(), idx)
    selected_pairs = list(self._modules.items())[idx]
    return self.__class__(OrderedDict(selected_pairs))

# 查找函数 被调用
def _get_item_by_idx(self, iterator, idx) -> T:
    """Return the ``idx``-th element produced by ``iterator``.

    Args:
        iterator: iterable to pick from; its length is taken to be
            ``len(self)``.
        idx: integer-like position; negative values count from the end.

    Raises:
        IndexError: if ``idx`` lies outside ``[-len(self), len(self))``.
    """
    total = len(self)
    position = operator.index(idx)
    if position < -total or position >= total:
        raise IndexError('index {} is out of range'.format(position))
    # Map a negative position onto its non-negative equivalent.
    position = position % total
    return next(islice(iterator, position, None))
# setitem
# 修改已经存在模块的信息 
# 以下代码将原来的0模块卷积核大小由5 × 5改为3 × 3
>>> model                       
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> model[0] = nn.Conv2d(1,20,3)
>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)

# 源码
def __setitem__(self, idx: int, module: Module) -> None:
    """Replace the submodule stored at position ``idx`` with ``module``.

    The position is translated into the name the existing submodule is
    registered under, and ``module`` is assigned to that attribute name.
    """
    target_name: str = self._get_item_by_idx(self._modules.keys(), idx)
    setattr(self, target_name, module)
# delitem 删除现有的module
# 以下代码演示删除原来的0模块
>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> del model[0]
>>> model
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)

# 源码
def __delitem__(self, idx: Union[slice, int]) -> None:
    """Remove the submodule(s) selected by ``idx``.

    For a slice, every name selected from the ordered list of registered
    names is deleted; otherwise the single name found at position ``idx``
    is deleted. The remaining names are left as they are.
    """
    if isinstance(idx, slice):
        doomed_names = list(self._modules.keys())[idx]
    else:
        doomed_names = [self._get_item_by_idx(self._modules.keys(), idx)]
    for name in doomed_names:
        delattr(self, name)

1.3 输出信息 __len__,__dir__

用法:

>>> model
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> len(model)
3
>>> dir(model)
['T_destination', '__annotations__',..., 'training', 'type', 'xpu', 'zero_grad']

源码:

def __len__(self) -> int:
    """Return the number of entries held in ``self._modules``."""
    registered = self._modules
    return len(registered)


# 返回对象的属性名列表,并过滤掉纯数字的名称(即以下标命名的子模块)
def __dir__(self):
    """Return the parent class's attribute names minus all-digit names.

    Names made up solely of digits are filtered out of the listing.
    """
    inherited = super(Sequential, self).__dir__()
    return [name for name in inherited if not name.isdigit()]

1.4 迭代器 __iter__

用法:

>>> model
Sequential(
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> for i in iter(model):
...     print(i)
... 
ReLU()
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
ReLU()
>>> 

源码:

def __iter__(self) -> Iterator[Module]:
    """Return an iterator over the values of ``self._modules``, in that mapping's order."""
    registered = self._modules.values()
    return iter(registered)

1.5 添加module append方法

用法:

>>> model
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
>>> model.append(nn.Conv2d(1,20,3))
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
  (4): Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1))
)

源码:

def append(self, module: Module) -> 'Sequential':
    """Register ``module`` under the key ``str(len(self))`` and return ``self``.

    Args:
        module: the module to add.

    Returns:
        The container itself, so calls can be chained.
    """
    next_key = str(len(self))
    self.add_module(next_key, module)
    return self

1.6 前向传播函数 forward方法

用法:
调用 model(data) 时,nn.Module 的 __call__ 会自动调用 forward(并处理已注册的 hook),因此不需要手动调用 forward;

# 前向传播
model(data)

# 而不是使用下面的
# model.forward(data)

源码:

# input依次经过每个module
def forward(self, input):
    """Feed ``input`` through every element of ``self`` in iteration order.

    Each element is called with the previous element's result; the final
    result is returned. With no elements, ``input`` is returned unchanged.
    """
    output = input
    for layer in iter(self):
        output = layer(output)
    return output

2 nn.ModuleList用法及源码

函数们:

Pytorch容器:nn.Sequential,ModuleList,ParameterList源码解析_第2张图片

用法:

class MyModule(nn.Module):
    """Example module that keeps ten ``nn.Linear(10, 10)`` layers in an ``nn.ModuleList``."""

    def __init__(self):
        super().__init__()
        # Ten Linear(10, 10) layers held by a ModuleList.
        layers = [nn.Linear(10, 10) for _ in range(10)]
        self.linears = nn.ModuleList(layers)

    def forward(self, x):
        """Apply, for each position ``i``, layer ``i // 2`` plus layer ``i`` to ``x``."""
        # A ModuleList supports both iteration and integer indexing.
        for position, layer in enumerate(self.linears):
            paired = self.linears[position // 2]
            x = paired(x) + layer(x)
        return x

这里只介绍insert/append/extend,其他函数与之前相同;

# insert 在任意位置添加
>>> linears = nn.ModuleList([nn.Linear(10, 10) for i in range(5)]) 
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Linear(in_features=10, out_features=10, bias=True)
  (4): Linear(in_features=10, out_features=10, bias=True)
)
>>> linears.insert(3,nn.Conv2d(1,20,5))
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
)

# append 在最后添加
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
)
>>> linears.append(nn.Conv2d(1,20,5))
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): Linear(in_features=10, out_features=10, bias=True)
  (6): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
)

# extend 一次添加多个
>>> linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)]) 
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
)
>>> linears.extend([nn.Conv2d(1,10,5),nn.ReLU()])
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (4): ReLU()
)

源码:

# 在任意位置插入一个模块
def insert(self, index: int, module: Module) -> None:
    r"""Insert a given module before a given index in the list.

    ``index`` is normalized the way ``list.insert`` does it: a negative
    index counts from the end, and an index beyond either end is clamped,
    so the module is placed first or last. Without this normalization a
    negative index raised ``KeyError`` after some entries had already been
    shifted, and an index past the end left a gap in the numeric keys.

    Args:
        index (int): position to insert before.
        module (nn.Module): module to insert.
    """
    size = len(self._modules)
    if index < 0:
        index = max(index + size, 0)
    elif index > size:
        index = size
    # Shift every entry at or after `index` one slot to the right,
    # starting from the end so nothing is overwritten early.
    for i in range(size, index, -1):
        self._modules[str(i)] = self._modules[str(i - 1)]
    self._modules[str(index)] = module


# 在最后添加一个模块
def append(self, module: Module) -> 'ModuleList':
    r"""Add ``module`` after the current last entry.

    The module is registered under the key ``str(len(self))``.

    Args:
        module (nn.Module): module to append

    Returns:
        The list itself.
    """
    tail_key = str(len(self))
    self.add_module(tail_key, module)
    return self

# 在最后添加多个模块
def extend(self, modules: Iterable[Module]) -> 'ModuleList':
    r"""Add every module from ``modules`` after the current last entry.

    Keys continue from the length the list had when the call started.

    Args:
        modules (iterable): iterable of modules to append

    Returns:
        The list itself.

    Raises:
        TypeError: if ``modules`` is not an iterable.
    """
    if not isinstance(modules, container_abcs.Iterable):
        raise TypeError("ModuleList.extend should be called with an "
                        "iterable, but got " + type(modules).__name__)
    # The starting key is fixed once, before anything is added.
    for position, module in enumerate(modules, start=len(self)):
        self.add_module(str(position), module)
    return self

3 nn.ModuleDict用法及源码

作用和nn.ModuleList一样,一个是list形式,一个是dict形式;
将list的函数换成dict的函数即可;

函数们:

Pytorch容器:nn.Sequential,ModuleList,ParameterList源码解析_第3张图片

用法:

class MyModule(nn.Module):
    """Example module that selects an operation and an activation by name."""

    def __init__(self):
        super().__init__()
        # A ModuleDict built from a plain dict.
        operations = {
            'conv': nn.Conv2d(10, 10, 3),
            'conv2': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3),
        }
        self.choices = nn.ModuleDict(operations)
        # A ModuleDict built from a list of [name, module] pairs.
        nonlinearities = [
            ['lrelu', nn.LeakyReLU()],
            ['prelu', nn.PReLU()],
        ]
        self.activations = nn.ModuleDict(nonlinearities)

    def forward(self, x, choice, act):
        """Apply ``self.choices[choice]`` and then ``self.activations[act]`` to ``x``."""
        transformed = self.choices[choice](x)
        return self.activations[act](transformed)

函数用法:

# 初始化一个ModuleDict
>>> choices = nn.ModuleDict({
...             'conv': nn.Conv2d(10, 10, 3),
...             'conv2': nn.Conv2d(10,10,3),
...             'pool': nn.MaxPool2d(3)
...         })
>>> 
>>> choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
# 查看
>>> choices.items()
odict_items([('conv', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))), ('conv2', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))), ('pool', MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False))])
# 只看键
>>> choices.keys() 
odict_keys(['conv', 'conv2', 'pool'])

# 只看值
>>> choices.values()
odict_values([Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)])

# 更新现有的结构
>>> choices.update({'pool':nn.MaxPool2d(5)})
>>> choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
# pop 删除一个
>>> choices.pop('pool')
MaxPool2d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
>>> choices
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
)
# 清空所有
>>> choices.clear()
>>> choices
ModuleDict()

源码:

class ModuleDict(Module):
    """Excerpt of the dictionary-style container for submodules.

    Only ``clear``, ``pop``, ``keys``, ``items``, ``values`` and ``update``
    are shown here. ``pop`` and ``update`` use item access on ``self``
    (``self[key]``, ``self[key] = ...``, ``del self[key]``), which this
    excerpt does not define.
    """

    def clear(self) -> None:
        """Remove all items from the ModuleDict.
        """
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (string): key to pop from the ModuleDict

        Returns:
            The module that was stored under ``key``.
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys.
        """
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs.
        """
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values.
        """
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Add or overwrite entries from a mapping or from key/value pairs.

        Args:
            modules: a mapping from string to module, or an iterable of
                two-element ``(key, module)`` items.

        Raises:
            TypeError: if ``modules`` is neither a mapping nor an iterable,
                or if one of its elements is not iterable.
            ValueError: if one of its elements does not have length 2.
        """
        # Mapping input: copy every key/value pair.
        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
            return

        # Anything else must be an iterable of pairs; reject it loudly
        # instead of silently ignoring the argument.
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        for position, pair in enumerate(modules):
            if not isinstance(pair, container_abcs.Iterable):
                raise TypeError("ModuleDict update sequence element "
                                "#" + str(position) + " should be Iterable; is " +
                                type(pair).__name__)
            if not len(pair) == 2:
                raise ValueError("ModuleDict update sequence element "
                                 "#" + str(position) + " has length " + str(len(pair)) +
                                 "; 2 is required")
            self[pair[0]] = pair[1]

4 nn.ParameterList

函数们:

Pytorch容器:nn.Sequential,ModuleList,ParameterList源码解析_第4张图片

用法:

class MyModule(nn.Module):
    """Example module that keeps ten 10x10 parameters in an ``nn.ParameterList``."""

    def __init__(self):
        super().__init__()
        weights = [nn.Parameter(torch.randn(10, 10)) for _ in range(10)]
        self.params = nn.ParameterList(weights)

    def forward(self, x):
        """Apply, for each position ``i``, parameter ``i // 2`` plus parameter ``i`` to ``x`` via ``mm``."""
        # A ParameterList supports both iteration and integer indexing.
        for position, weight in enumerate(self.params):
            shared = self.params[position // 2]
            x = shared.mm(x) + weight.mm(x)
        return x

extra_repr函数

def extra_repr(self) -> str:
    """Build one summary line per entry of ``self._parameters``.

    Each line shows the entry's key, its type name from ``torch.typename``,
    its size joined with 'x', and a GPU index suffix when the parameter is
    on CUDA. Lines are joined with newlines; no entries gives ''.
    """
    def describe(name, param):
        dims = 'x'.join(map(str, param.size()))
        placement = ' (GPU {})'.format(param.get_device()) if param.is_cuda else ''
        summary = 'Parameter containing: [{} of size {}{}]'.format(
            torch.typename(param), dims, placement)
        return '  (' + str(name) + '): ' + summary

    return '\n'.join(describe(name, param) for name, param in self._parameters.items())

# 用法:
>>> params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
>>> params.extra_repr()
'  (0): Parameter containing: [torch.FloatTensor of size 10x10]\n  (1): Parameter containing: [torch.FloatTensor of size 10x10]\n  (2): Parameter containing: [torch.FloatTensor of size 10x10]\n  (3): Parameter containing: [torch.FloatTensor of size 10x10]\n  (4): Parameter containing: [torch.FloatTensor of size 10x10]\n  (5): Parameter containing: [torch.FloatTensor of size 10x10]\n  (6): Parameter containing: [torch.FloatTensor of size 10x10]\n  (7): Parameter containing: [torch.FloatTensor of size 10x10]\n  (8): Parameter containing: [torch.FloatTensor of size 10x10]\n  (9): Parameter containing: [torch.FloatTensor of size 10x10]'

5 nn.ParameterDict

函数们:

Pytorch容器:nn.Sequential,ModuleList,ParameterList源码解析_第5张图片

用法:

class MyModule(nn.Module):
    """Example module that selects one of two 5x10 parameters by name."""

    def __init__(self):
        super().__init__()
        sides = {
            'left': nn.Parameter(torch.randn(5, 10)),
            'right': nn.Parameter(torch.randn(5, 10)),
        }
        self.params = nn.ParameterDict(sides)

    def forward(self, x, choice):
        """Return ``self.params[choice].mm(x)``."""
        selected = self.params[choice]
        return selected.mm(x)

你可能感兴趣的:(pytorch,python,深度学习)