Pytorch学习(4) —— nn.Sequential nn.ModuleList nn.ModuleDict 源码解析

下面给出模型基本类相关的源码分析。

本节大部分函数重载用法与上节内容相似,在本节一样的内容将不会进行详细描述

文章目录

  • 1 nn.Sequential
  • 2 nn.ModuleList
  • 3 nn.ModuleDict
  • 总结

1 nn.Sequential

这个类可以快速的构建一个模型,下面是官方给的一个类。

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

这个就是一个序列化模型,可以直接forward测试,执行时候也是按照顺序进行执行,同时可以使用角标对内部结构进行删除修改,下面给出对应的源码。

class Sequential(Module):
    def __init__(self, *args):
        super(Sequential, self).__init__()
        # Sequential 可以是个字典,这样每一层就按照定义的键命名
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else: # 否则按照序号命名
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    # 这个与上节类似,为了获取访问模型序列内部元素的角标
    def _get_item_by_idx(self, iterator, idx):
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError('index {} is out of range'.format(idx))
        idx %= size
        return next(islice(iterator, idx, None))

    # 使用角标访问序列,比如S[i]或切片访问S[i:j],用法与ParameterList相似。
    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    # 使用角标设置模型序列,仅能使用角标设置,不支持切片。
    def __setitem__(self, idx, module):
        key = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    # 删除模型序列,可以使用角标,也可以使用切片
    def __delitem__(self, idx):
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    # 可以使用len(S)获取模型个数
    def __len__(self):
        return len(self._modules)

    # 返回类里面函数名
    def __dir__(self):
        keys = super(Sequential, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    # 模型的执行方法,很显然,按照模型定义的顺序执行
    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

如果想像Parameter类等可以输出相似的参数信息,自己在文件中补充函数即可。

2 nn.ModuleList

这个就是创建一个模型列表,下面给出官网的一个例子,构造一个模型列表,方便前向执行时访问目标层,注意,这个不是模型序列,这个只是存模型各种层的一个List,所以不存在前向的问题

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

下面给出这个类的源码(说真的,与上一节内容太相似了,真不想写了,o(╥﹏╥)o)。

class ModuleList(Module):
    def __init__(self, modules=None):
        super(ModuleList, self).__init__()
        if modules is not None:
            self += modules # 重载了+= 调用extend函数
    
    # 验证并获取访问所需的角标
    def _get_abs_string_index(self, idx):
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    # 同上
    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]
    
    # 同上
    def __setitem__(self, idx, module):
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    # 删除,方法同上
    def __delitem__(self, idx):
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # 为了保持序号,删除后将会重新构建模型的index
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    # 获取长度
    def __len__(self):
        return len(self._modules)

    # 获取模型迭代器
    def __iter__(self):
        return iter(self._modules.values())

    # 重载 += 符号
    def __iadd__(self, modules):
        return self.extend(modules)

    def __dir__(self):
        keys = super(ModuleList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    # 在index处插入一个模型module
    def insert(self, index, module):
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    # 在list的尾部插入一个模型module
    def append(self, module):
        self.add_module(str(len(self)), module)
        return self

    # 在模型尾部插入一个list模型modules
    def extend(self, modules):
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

3 nn.ModuleDict

模型字典,这里面存了一大堆模型信息,与ModuleList相似,这个只是个容器,不涉及前向执行问题,下面给出官方的例子。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

下面是这个类对应的源码,与ParameterDict的源码几乎一模一样^_^。到这为止,大部分函数前面都介绍过,都能看懂干啥的,我只补上新的函数

class ModuleDict(Module):
    def __init__(self, modules=None):
        super(ModuleDict, self).__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key):
        return self._modules[key]

    def __setitem__(self, key, module):
        self.add_module(key, module)

    def __delitem__(self, key):
        del self._modules[key]

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        return iter(self._modules)

    def __contains__(self, key):
        return key in self._modules

    def clear(self):
        self._modules.clear()

    def pop(self, key):
        v = self[key]
        del self[key]
        return v

    def keys(self):
        return self._modules.keys()

    def items(self):
        return self._modules.items()

    def values(self):
        return self._modules.values()

    def update(self, modules):
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        if isinstance(modules, container_abcs.Mapping):
            if isinstance(modules, (OrderedDict, ModuleDict)):
                for key, module in modules.items():
                    self[key] = module
            else:
                for key, module in sorted(modules.items()):
                    self[key] = module
        else:
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(m).__name__)
                # 其实这里想表达的是,如果输入的不是一对键值,也可以输入一个矩阵,个数为2
                # 然后以第一个为键,第二个为值
                # 个人觉得这个更像是为了考虑各种操作而写的函数
                # 基本保证你怎么写都能执行.......
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                self[m[0]] = m[1]

总结

这节内容完成起来太容易了,因为大部分函数相似度太强了,如果本部分有没看懂的,一定要结合上一个博客一起看。

如果还有不明朗的,欢迎评论区讨论。

你可能感兴趣的:(Pytorch学习,python,机器学习,人工智能)