python pyrotch模型中的参数访问

获取模型对象1

res = model.decoder

获取模型参数对象2

res = model.decoder.parameters()

提示:
我们可以通过 Module 类的 parameters() 或者 named_parameters ⽅法来访问所有参数(以迭代器的形式返回),后者除了返回参数 Tensor 外还会返回其名字。

Pytorch中state_dict()、named_parameters()和parameters()的区别:optim.step只能更新nn.parameter.Parameter类型的参数。不可学习参数将会通过Module.register_parameter()注册在self._buffers中。named_parameters 保存了参数名与具体值,因此可用于锁住某些层的参数,让其在训练的时候不更新参数

打印参数3

print(type(model.named_parameters()))
for name, param in model.named_parameters():
    print(name, param.size())

此时能在Debug4中看到param的值,也可获取模型的第i层,并查看其中的参数:

res = MLP.model
weight_of_i = list( res[i].parameters())[0]
bias_of_i = list( res[i].parameters())[1]

打印模型中的参数个数


  1. python pyrotch模型中的参数访问_第1张图片 ↩︎

  2. python pyrotch模型中的参数访问_第2张图片 ↩︎

  3. python pyrotch模型中的参数访问_第3张图片 ↩︎

  4. python pyrotch模型中的参数访问_第4张图片 ↩︎

你可能感兴趣的:(语言学习笔记,python,开发语言,后端)