为了更好理解Pytorch基本类的实现方法,我这里给出了关于参数方面的3个类的源码详解。
此部分可以更好的了解实现逻辑结构,有助于后续代码理解,学pytorch的话这个不是必须掌握的,看不懂也没关系。
此部分参考《pytorch源码阅读系列之Parameter类》,《通俗的讲解Python中的__new__()方法》
因为Parameter继承于torch.Tensor,没有新的变量和添加函数,只是对一些辅助函数进行了定义
Parameter作为Module类的参数,可以自动的添加到Module类的参数列表中,并且可以使用Module.parameters()提供的迭代器获取到,所以这个类是一切网络结构数据的核心。
class Parameter(torch.Tensor):
# 这个方法比__init__方法更先执行,这里就理解为一种初始化方法
# 详细参考《通俗的讲解Python中的__new__()方法》
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)
# 为了方便实用deepcopy方法,对当前数据进行深拷贝,正常的copy方法只拷贝一层,
# 简单的来说list的list,最好用深拷贝。
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result
# 一种可视化方法,给print使用
def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
# 用于替代reduce方法
def __reduce_ex__(self, proto):
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, OrderedDict())
)
这个类实际上是将一个Parameter的List转为ParameterList,如下例所示[nn.Parameter(torch.randn(10, 10)) for i in range(10)]
类型是List,List的每个元素是Parameter,然后这个List作为参数传入这个类构造ParameterList类型。
ParameterList输入一定是一个Parameter的List,其他类型会报错,在注册时候就会提示元素不是Parameter类型。
parms = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
下面是对应的源码。
class ParameterList(Module):
def __init__(self, parameters=None): # parameters 是一个python list 类型
super(ParameterList, self).__init__()
# 这里的+=运算是经过重载的,__iadd__定义,可以看出,实际上是调用了extend方法,
# 将parameters 注册到_parameters中
if parameters is not None:
self += parameters
# 针对非slice的index,判断是否满足取值的条件,并返回对应角标字符串
def _get_abs_string_index(self, idx):
idx = operator.index(idx) # 判断输入角标的位置是否为整数
# 这里重载的__len__,返回_parameters的个数
if not (-len(self) <= idx < len(self)): # 判断是否在参数范围内
raise IndexError('index {} is out of range'.format(idx))
if idx < 0: #对于负号的问题,就选择倒数的元素
idx += len(self)
return str(idx) # 返回对应的角标字符串
# 使得这个类可以通过角标访问,比如P[i]这种
def __getitem__(self, idx):
if isinstance(idx, slice): # 判断这个角标是否为切片就是P[i:j]这种
# _parameters是OrderedDict类型,返回值转换为list,
# __class__表示转换为当前类型,所以,通过切片返回的List仍然是ParameterList类型
return self.__class__(list(self._parameters.values())[idx])
else:
idx = self._get_abs_string_index(idx) # 检验角标正确性
return self._parameters[str(idx)] # 返回一个数据,数据类型为Parameter
# 使得这个类可以通过角标访问,比如P[i] = Q这种,这里面不支持切片复制
def __setitem__(self, idx, param):
idx = self._get_abs_string_index(idx)
return self.register_parameter(str(idx), param)
# 重载len用法,可以使用len(P)统计list个数
def __len__(self):
return len(self._parameters)
# 重载迭代器算法,可以用于 for i in P这种
def __iter__(self):
return iter(self._parameters.values())
# 重载自加算法,比如P += Q,等价于 P.extend(Q)
def __iadd__(self, parameters):
return self.extend(parameters)
# 列出这个类所有的属性,重载后可以使用dir(P)
def __dir__(self):
keys = super(ParameterList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
# 在list末尾添加一个Parameter
def append(self, parameter):
self.register_parameter(str(len(self)), parameter)
return self
# 在list末尾添加一个Parameter的list,也可以是ParameterList类型
def extend(self, parameters):
if not isinstance(parameters, container_abcs.Iterable):
raise TypeError("ParameterList.extend should be called with an "
"iterable, but got " + type(parameters).__name__)
offset = len(self)
for i, param in enumerate(parameters):
self.register_parameter(str(offset + i), param)
return self
# 可以理解为ParameterList可视化方法,下面给一个调用的例子
# (0): Parameter containing: [torch.FloatTensor of size 10x10]
# (1): Parameter containing: [torch.FloatTensor of size 10x10]
# (2): Parameter containing: [torch.FloatTensor of size 10x10]
def extra_repr(self):
child_lines = []
for k, p in self._parameters.items():
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
child_lines.append(' (' + str(k) + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
ParameterDict 是一个字典类源码,与python的字典非常相似,下面就是字典的一个例子,输入参数是个普通字典,然后转换为ParameterDict类型。
params = nn.ParameterDict({ 'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10))})
下面给出这个类的源码,并对其进行详细分析理解。
class ParameterDict(Module):
def __init__(self, parameters=None):
super(ParameterDict, self).__init__()
if parameters is not None:
self.update(parameters) # 更新字典
def __getitem__(self, key): # 同上一节,可以使用键访问值
return self._parameters[key]
def __setitem__(self, key, parameter): # 同上一节,可以使用键设置值
self.register_parameter(key, parameter)
def __delitem__(self, key): # 删除某个键,可使用del删除
del self._parameters[key]
def __len__(self): # 返回字典个数
return len(self._parameters)
def __iter__(self): # 同上一节,可以得到迭代器,迭代器用键表示
return iter(self._parameters.keys())
def __contains__(self, key): # 判断当前key是否在字典中,重载关键字in, key in dict
return key in self._parameters
def clear(self): # 清空字典
self._parameters.clear()
def pop(self, key): # 删除某个键,并返回其值。
v = self[key]
del self[key]
return v
def keys(self): # 返回所有的键的名称
return self._parameters.keys()
def items(self): # 同字典的item用法
return self._parameters.items()
def values(self): # 返回所有的值
r"""Return an iterable of the ParameterDict values.
"""
return self._parameters.values()
def update(self, parameters): # 输入新的字典,更新当前的参数字典
if not isinstance(parameters, container_abcs.Iterable): # 保证输入一定是个字典
raise TypeError("ParametersDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(parameters).__name__)
if isinstance(parameters, container_abcs.Mapping): # 判断是不是一个Mapping类型
if isinstance(parameters, (OrderedDict, ParameterDict)): #判断是不是已知类型
for key, parameter in parameters.items():
self[key] = parameter
else:
for key, parameter in sorted(parameters.items()):
self[key] = parameter
else:
# 感觉这里是为了适应其他的字典类,毕竟有可能用户自己也写个字典类
for j, p in enumerate(parameters):
if not isinstance(p, container_abcs.Iterable):
raise TypeError("ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" +
type(p).__name__)
if not len(p) == 2:
raise ValueError("ParameterDict update sequence element "
"#" + str(j) + " has length " + str(len(p)) +
"; 2 is required")
self[p[0]] = p[1]
# 字典可视化
# (left): Parameter containing: [torch.FloatTensor of size 5x10]
# (right): Parameter containing: [torch.FloatTensor of size 5x10]
def extra_repr(self):
child_lines = []
for k, p in self._parameters.items():
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
child_lines.append(' (' + k + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
关于参数的三个类的分析就到这里了,其实感觉跟正常的python用法也没啥区别,为了方便用户使用pytorch,官方重载了大量的函数,方便用户使用,很大程度上降低了使用难度。后续,我再对模型的几个类比如Sequential,ModuleList,ModuleDict进行分析,Module这个类我估计不会进行分析了,将近1000行,实现了太多太多功能,我觉得太底层了,就不分析了,如果有人感兴趣的话,欢迎一起讨论研究。