下面给出模型基本类相关的源码分析。
本节大部分函数重载用法与上节内容相似,在本节一样的内容将不会进行详细描述
这个类可以快速的构建一个模型,下面是官方给的一个类。
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
这个就是一个序列化模型,可以直接forward测试,执行时候也是按照顺序进行执行,同时可以使用角标对内部结构进行删除修改,下面给出对应的源码。
class Sequential(Module):
def __init__(self, *args):
super(Sequential, self).__init__()
# Sequential 可以是个字典,这样每一层就按照定义的键命名
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else: # 否则按照序号命名
for idx, module in enumerate(args):
self.add_module(str(idx), module)
# 这个与上节类似,为了获取访问模型序列内部元素的角标
def _get_item_by_idx(self, iterator, idx):
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError('index {} is out of range'.format(idx))
idx %= size
return next(islice(iterator, idx, None))
# 使用角标访问序列,比如S[i]或切片访问S[i:j],用法与ParameterList相似。
def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
# 使用角标设置模型序列,仅能使用角标设置,不支持切片。
def __setitem__(self, idx, module):
key = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
# 删除模型序列,可以使用角标,也可以使用切片
def __delitem__(self, idx):
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
# 可以使用len(S)获取模型个数
def __len__(self):
return len(self._modules)
# 返回类里面函数名
def __dir__(self):
keys = super(Sequential, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
# 模型的执行方法,很显然,按照模型定义的顺序执行
def forward(self, input):
for module in self._modules.values():
input = module(input)
return input
如果想像Parameter类等可以输出相似的参数信息,自己在文件中补充函数即可。
这个就是创建一个模型列表,下面给出官网的一个例子,构造一个模型列表,方便前向执行时访问目标层,注意,这个不是模型序列,这个只是存模型各种层的一个List,所以不存在前向的问题。
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
下面给出这个类的源码(说真的,与上一节内容太相似了,真不想写了,o(╥﹏╥)o)。
class ModuleList(Module):
def __init__(self, modules=None):
super(ModuleList, self).__init__()
if modules is not None:
self += modules # 重载了+= 调用extend函数
# 验证并获取访问所需的角标
def _get_abs_string_index(self, idx):
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
if idx < 0:
idx += len(self)
return str(idx)
# 同上
def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]
# 同上
def __setitem__(self, idx, module):
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), module)
# 删除,方法同上
def __delitem__(self, idx):
if isinstance(idx, slice):
for k in range(len(self._modules))[idx]:
delattr(self, str(k))
else:
delattr(self, self._get_abs_string_index(idx))
# 为了保持序号,删除后将会重新构建模型的index
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
# 获取长度
def __len__(self):
return len(self._modules)
# 获取模型迭代器
def __iter__(self):
return iter(self._modules.values())
# 重载 += 符号
def __iadd__(self, modules):
return self.extend(modules)
def __dir__(self):
keys = super(ModuleList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
# 在index处插入一个模型module
def insert(self, index, module):
for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
# 在list的尾部插入一个模型module
def append(self, module):
self.add_module(str(len(self)), module)
return self
# 在模型尾部插入一个list模型modules
def extend(self, modules):
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleList.extend should be called with an "
"iterable, but got " + type(modules).__name__)
offset = len(self)
for i, module in enumerate(modules):
self.add_module(str(offset + i), module)
return self
模型字典,这里面存了一大堆模型信息,与ModuleList相似,这个只是个容器,不涉及前向执行问题,下面给出官方的例子。
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
下面是这个类对应的源码,与ParameterDict的源码几乎一模一样^_^。到这为止,大部分函数前面都介绍过,都能看懂干啥的,我只补上新的函数
class ModuleDict(Module):
def __init__(self, modules=None):
super(ModuleDict, self).__init__()
if modules is not None:
self.update(modules)
def __getitem__(self, key):
return self._modules[key]
def __setitem__(self, key, module):
self.add_module(key, module)
def __delitem__(self, key):
del self._modules[key]
def __len__(self):
return len(self._modules)
def __iter__(self):
return iter(self._modules)
def __contains__(self, key):
return key in self._modules
def clear(self):
self._modules.clear()
def pop(self, key):
v = self[key]
del self[key]
return v
def keys(self):
return self._modules.keys()
def items(self):
return self._modules.items()
def values(self):
return self._modules.values()
def update(self, modules):
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(modules).__name__)
if isinstance(modules, container_abcs.Mapping):
if isinstance(modules, (OrderedDict, ModuleDict)):
for key, module in modules.items():
self[key] = module
else:
for key, module in sorted(modules.items()):
self[key] = module
else:
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError("ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" +
type(m).__name__)
# 其实这里想表达的是,如果输入的不是一对键值,也可以输入一个矩阵,个数为2
# 然后以第一个为键,第二个为值
# 个人觉得这个更像是为了考虑各种操作而写的函数
# 基本保证你怎么写都能执行.......
if not len(m) == 2:
raise ValueError("ModuleDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) +
"; 2 is required")
self[m[0]] = m[1]
这节内容完成起来太容易了,因为大部分函数相似度太强了,如果本部分有没看懂的,一定要结合上一个博客一起看。
如果还有不明朗的,欢迎评论区讨论。