【pytorch】named_parameters()、parameters()、state_dict()==>给出网络的名字和参数的迭代器

torch中存在3个功能极其类似的方法,它们分别是model.parameters()、model.named_parameters()、model.state_dict(),下面就具体来说说这三个函数的差异:

1.首先,说说比较接近的model.parameters()和model.named_parameters()。这两者唯一的差别在于,
named_parameters()返回的list中,每个元组(与list相似,只是数据不可修改)打包了2个内容,分别是layer-name和layer-param(网络层的名字和参数的迭代器)
而parameters()只有后者layer-param(参数的迭代器)

2.后面只谈model.named_parameters()和model.state_dict()间的差别。

它们的差异主要体现在3方面:

  1. 返回值类型不同
  2. 存储的模型参数的种类不同
  3. 返回的值的require_grad属性不同
named_parameters() state_dict()
将layer_name : layer_param的键值信息打包成一个元祖然后再存到list当中 将layer_name : layer_param的键值信息存储为dict形式
只保存可学习、可被更新的参数,model.buffer()中的参数不包含在model.named_parameters()中 存储的是该model中包含的所有layer中的所有参数
require_grad属性都是True 存储的模型参数tensor的require_grad属性都是False

为何model.parameters()迭代出来的所有参数的require_grad属性都是True,因为它们在被创建时,默认的require_grad就是True。这也符合逻辑,即,使用nn.Parameter()创建的变量是模型参数,本就是要参与学习和更新的;而不参与模型学习和更新的呢?我们再看一个例子:

if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

定义的三个参数是不参与模型学习的,所以定义在buffer中,buffer的定义方法为self.register_buffer(),所以模型不参与学习和更新的参数是这样定义的、

pytorch中state_dict()和named_parameters()的差别,以及model.buffer/model.parameter_甘如荠-CSDN博客_named_parameters

你可能感兴趣的:(Pytorch相关,pytorch,网络,人工智能)