Parameter源码如下
class Parameter(torch.Tensor, metaclass=_ParameterMeta):
r"""A kind of Tensor that is to be considered a module parameter.
Parameters are :class:`~torch.Tensor` subclasses, that have a
very special property when used with :class:`Module` s - when they're
assigned as Module attributes they are automatically added to the list of
its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator.
Assigning a Tensor doesn't have such effect. This is because one might
want to cache some temporary state, like last hidden state of the RNN, in
the model. If there was no such class as :class:`Parameter`, these
temporaries would get registered too.
Args:
data (Tensor): parameter tensor.
requires_grad (bool, optional): if the parameter requires gradient. See
:ref:`locally-disable-grad-doc` for more details. Default: `True`
"""
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.empty(0)
if type(data) is torch.Tensor or type(data) is Parameter:
# For ease of BC maintenance, keep this path for standard Tensor.
# Eventually (tm), we should change the behavior for standard Tensor to match.
return torch.Tensor._make_subclass(cls, data, requires_grad)
# Path for custom tensors: set a flag on the instance to indicate parameter-ness.
t = data.detach().requires_grad_(requires_grad)
if type(t) is not type(data):
raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} "
"requires that detach() returns an instance of the same type, but return "
f"type {type(t).__name__} was found instead. To use the type as a "
"Parameter, please correct the detach() semantics defined by "
"its __torch_dispatch__() implementation.")
t._is_param = True
return t
# Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
# are still considered that custom tensor type and these methods will not be called for them.
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result
def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, OrderedDict())
)
__torch_function__ = _disabled_torch_function_impl
对源码中的第一段注释的理解
nn.parameter是一种绑定到某一module的参数列表中的Tensor,是继承自torch.Tensor的子类,当其在nn.module中使用时,nn.parameter会被自动添加到该模块的parameter list中,即加入到parameter()这个迭代器中,作为可训练的参数.
可以将其看做为一个类型转换的函数,将一个不可训练的Tensor转换为一个可训练的parameter,并且将这个parameter与该module进行绑定,被送入优化器中,成为该模型的一部分,随着训练不断进行更新(requires_grad == True 的情况)。
而若只是声明一个tensor不会产生如上效果,torch.tensor([1,2,3],requires_grad=True)只是将参数变成可训练的,并没有绑定在module的parameter列表中,且tensor的requires_grad参数默认为False。
参数 data、requires_grad
data参数默认为一维的tensor0,在使用时通常放入一个初始化的值,但要对其尺寸进行准确的指定,e.g.
nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.Parameter(torch.rand(input_size, output_size))
requires_grad参数是一个bool值, 默认值为True,当设定为True为随着训练器的迭代而不断更新,设置为False为 不可训练的,e.g ViT中的位置编码是需要固定住的,在实例化时就需要声明该参数为false
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
使用方式
若直接像如下方式声明是错误的,未对其进行实例化
nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.Parameter(torch.rand(input_size, output_size))
正确的使用方式如下(ViT、nn.Linear源码中的声明方式)
# ViT
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
# nn.Linear
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
注意到register_parameter的声明方式,若不在module内部声明parameter,在外部也想将其放入parameter列表中,就可以使用该方法
pytorch官网中给出的解释如下
第一个参数为给该参数起的名字,第二个参数为nn.parameter类型的数据,不想设定可以直接写None
若需要设定,其基本使用方法如下
ran = nn.Parameter(torch.rand(6, 2))
model.register_parameter('bias',ran)
参考链接:
https://blog.csdn.net/hxxjxw/article/details/107904012
https://blog.csdn.net/m0_46653437/article/details/112444979
https://pytorch.org/docs/1.2.0/nn.html#torch.nn.Module.register_parameter
https://blog.csdn.net/weixin_44966641/article/details/118730730