torch.nn.parameter详解

:--

      • 目录:
        • 参考:
        • 1、parameter基本解释:
        • 2、参数requires_grad的深入理解:
          • 2.1 Parameter级别的requires_grad
          • 2.2Module级别的requires_grad标志

目录:

参考:

Parameter — PyTorch 1.12 documentation

1、parameter基本解释:

CLASS torch.nn.parameter.Parameter(data=None, requires_grad=True)
"""
A kind of Tensor that is to be considered a module parameter.

Parameters are Tensor subclasses, that have a very special property when used with Module s - when they’re assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in parameters() iterator. Assigning a Tensor doesn’t have such effect. This is because one might want to cache some temporary state, like last hidden state of the RNN, in the model. If there was no such class as Parameter, these temporaries would get registered too.

data (Tensor) – parameter tensor.

requires_grad (bool, optional) – if the parameter requires gradient. See Locally disabling gradient computation for more details. Default: True
"""

torch.nn.parameter.Parameter 类用于Module里面自定义参数,当其作为Module的属性时,会自动添加到模型的参数列表中,可以通过parameters()迭代器读取:例如RNN的最后一个隐藏状态,Transfermor、VIT、GNN都会用到的。

参数data:指的是Tensor

参数requires_grad:指的是是否需要自动计算梯度(根据实际情况来定,如果需要学习的权重,需要自动计算梯度,如果不参与学习,只是作为保存变量则不需要自动计算梯度)

2、参数requires_grad的深入理解:

2.1 Parameter级别的requires_grad

Autograd mechanics — PyTorch 1.12 documentation

requires_grad参数和pytorch的自动计算梯度的机制有关。requires_grad是一个决定是否需要反向传播时候计算梯度的标志,如果True,则在前向传递期间,将节点记录在后向图中。在后向传递 (.backward()) 时只有 requires_grad=True 的叶张量才会将梯度累积到它们的 .grad 字段中。 注意:即使每个张量都有这个标志,设置它只对leaf tensors(没有 grad_fn 的张量,例如,nn.Module 的参数)有意义。很明显所有no leaf tensors(具有 grad_fn 的张量,与leaf tensors有关的后向图的张量)都会自动具有 require_grad=True,no leaf tensors计算梯度作为中间结果来计算叶tensors的 grad 。 设置 requires_grad 可以控制模型的哪些部分需要梯度计算。举个例子:

例如,如果需要在模型微调期间冻结部分预训练模型。 要冻结模型的某些部分,只需将 .requires_grad(False) 应用于应用于不想更新的参数。如上所述,由于使用这些参数作为输入的计算不会记录在前向传递中,因此它们不会在后向传递中更新其 .grad 字段,因为它们不会成为第一个后向图的一部分节点。

2.2Module级别的requires_grad标志

根据需要,也可以使nn.Module.requires_grad() 在模块级别设置requires_grad。当应用于模块时, .requires_grad_() 会影响模块的所有参数(默认情况下 requires_grad=True )。

你可能感兴趣的:(软件开发相关的技能,深度学习,神经网络,pytorch)