pytorch中获取模型参数:state_dict和parameters两个方法的差异比较

一、本文的模型案例

代码如下:

import torch
import torch.nn.functional as F
from torch.optim import SGD
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)
 
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet() # 构造模型

二、model.state_dict()方法

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)。这个方法的作用一方面是方便查看某一个层的权值和偏置数据,另一方面更多的是在模型保存的时候使用。

2.1 Module的层的权值以及bias查看

print(type(model.state_dict()))  # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict”
 
for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
    print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
'''
conv1.weight     torch.Size([32, 3, 3, 3])
conv1.bias       torch.Size([32])
conv2.weight     torch.Size([32, 3, 3, 3])
conv2.bias       torch.Size([32])
dense1.weight    torch.Size([128, 288])
dense1.bias      torch.Size([128])
dense2.weight    torch.Size([10, 128])
dense2.bias      torch.Size([10])
'''

2.2 优化器optimizer的state_dict()方法

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

optimizer = SGD(model.parameters(),lr=0.001,momentum=0.9)
for var_name in optimizer.state_dict():
    print(var_name,'\t',optimizer.state_dict()[var_name])
 
'''
state    {}
param_groups     [{'lr': 0.001, 
                   'momentum': 0.9,
                   'dampening': 0, 
                   'weight_decay': 0, 
                   'nesterov': False, 
                   'params': [1412966600640, 1412966613064, 1412966613136, 1412966613208, 
                             1412966613280, 1412966613352, 1412966613496, 1412966613568]
                 }]
'''

三、model.parameters()方法

这个方法也会获得模型的参数信息,如下:

print(type(model.parameters()))  # 返回的是一个generator
 
for para in model.parameters():
    print(para.size())  # 只查看形状
'''
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([128, 288])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
'''

从这里可以看出,其实这个state_dict方法所得到结果差不多,不同的是,model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key。

其实Module还有一个与parameters类似的函数,named_parameters,而且parameters正是通过named_parameters来实现的,

看一下parameters的定义,很简单:
 

def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param

来一起看一下named_parameters的简单使用。

print(type(model.named_parameters()))  # 返回的是一个generator
 
for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple 
    '''
    是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
    '''
    print(para[0],'\t',para[1].size())
 
'''
conv1.weight     torch.Size([32, 3, 3, 3])
conv1.bias       torch.Size([32])
conv2.weight     torch.Size([32, 3, 3, 3])
conv2.bias       torch.Size([32])
dense1.weight    torch.Size([128, 288])
dense1.bias      torch.Size([128])
dense2.weight    torch.Size([10, 128])
dense2.bias      torch.Size([10])
'''

总结:model.state_dict()、model.parameters()、model.named_parameters()这三个方法都可以查看Module的参数信息,用于更新参数,或者用于模型的保存。

你可能感兴趣的:(pytorch中获取模型参数:state_dict和parameters两个方法的差异比较)