pytoch中class定义神经网络的参数保存与加载

一、定义一个容易识别的网络

在正式介绍模型的保存和加载之前,我们首先定义一个基本的网络Net,它只包含一个全连接层:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer = nn.Linear(1, 1)
        self.layer.weight = nn.Parameter(torch.FloatTensor([[10]]))
        self.layer.bias = nn.Parameter(torch.FloatTensor([1]))

    def forward(self, x):
         y = self.layer(x)
        return y

二、保存Net的参数值

保存模型参数之前,需要知道Net的参数值存储在其state_dict(状态字典)属性中,我们查看一下net的state_dict包含哪些参数:

print(net.state_dict())

我们将会得到net包含的所有参数名称与参数值

包含一个weight和一个bias,对应的值分别是10和1,和我们之前定义的全连接层一致。我们需要保存的就是这个state_dict,保存的函数为“torch.save()”,参数是我们需要保存的dict和存储路径

torch.save(obj=net.state_dict(), f="models/net.pth")

这有可能会报错TypeError: state_dict() missing 1 required positional argument: 'self'

报错原因是上面定义的class是一个类,不能直接obj=,正确的形式是obj=Net()

现在,同级目录models下将会出现net.pth文件,pth文件中的内容就是net的参数名称和值对应的state_dict,如下:

三、加载Net参数值并用于新的模型

最后一个步骤就是从pth文件中重新获取Net参数值,并把参数值装载到新定义的Model对象中。这里我们重新定义一个结构和Net类相同的类Model,区别仅仅是Model参数初始值和Net不同,代码如下:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Linear(1, 1)
        self.layer.weight = nn.Parameter(torch.FloatTensor([[0]]))
        self.layer.bias = nn.Parameter(torch.FloatTensor([0]))

    def forward(self, x):
        out = self.layer(x)
        return out

这里将Model的初始值权重w和偏差都设置为0,查看其state_dict:

model = Model()
print(model.state_dict())

得到的w和b值与预期相同,均为0,如下:

现在,我们将model对象的参数值设置为net.pth中的值,需要使用“model.load_state_dict()”函数重置model的参数值为"torch.load(models/ net.pth)"中的参数值,如下:

model.load_state_dict(torch.load("models/net.pth"))
print(model.state_dict())

至此,model的w和b值就不再是0了,而是net中w和b对应的10和1,如下:

其中参数值重载的核心函数为“model.load_state_dict()”,每个继承自nn.Module的网络都能通过这个函数设定参数值。

四、优化器与epoch的保存

保存优化器参数值和epoch值的主要目的是用于继续训练,保存的流程依旧是先“torch.save()”再“torch.load_state_dict()”,我们首先定义一个Adam优化器、一个任意的epoch值与net如下:

net = Net()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
epoch = 96

现在,创建一个字典来保存所有的对象,并用save函数保存这个字典

all_states = {"net": net.state_dict(), "Adam": Adam.state_dict(), "epoch": epoch}
torch.save(obj=all_states, f="models/all_states.pth")

所有的对象都被保存到models文件夹下了:

可以使用load()函数把所有的对象再次提取出来:

reload_states = torch.load("models/all_states.pth")
print(reload_states)

五、总结

pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。其中最重要的部分是“字典”的概念,因为参数对象的存储是需要“名称”——“值”对应(即键值对),读取时也是通过键值对读取的。

你可能感兴趣的:(highway-env,python)