有时候自己写的网络模型在训练时想看某个epoch的模型参数,或者想按照某一epoch的模型参数进行测试,就需要看log相关的checkpoints。
方法如下:
在 PyTorch 中,可以使用 torch.save() 函数来保存模型的状态字典。可以使用一个自定义的 checkpoint 函数来在每个 epoch 结束时保存模型的状态字典。
import torch
import os
def checkpoint(model, epoch, optimizer, loss, checkpoint_dir):
"""
Saves a checkpoint of the model at a given epoch
"""
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'loss': loss
}
filename = os.path.join(checkpoint_dir, f'checkpoint-epoch{epoch}.pt')
torch.save(state, filename)
# 训练过程中,在每个 epoch 结束时调用 checkpoint 函数保存模型状态字典
for epoch in range(30):
# 训练过程中的代码
# ...
if epoch > 20:
checkpoint(model, epoch, optimizer, loss, checkpoint_dir)
使用 torch.load() 函数可以从保存的状态字典中恢复模型的权重。示例代码如下:
# 加载某个 epoch 的模型
filename = os.path.join(checkpoint_dir, 'checkpoint-epoch25.pt')
state = torch.load(filename)
model.load_state_dict(state['state_dict'])
#写一个新的MLP网络,网络参数读取此前已经存的epoch=25时的网络模型
#假设我们已经将模型保存在名为 model_epoch25.pth 的文件中,以下是加载该模型并用它进行测试的示例
#网络模型和训练时保持一致就行
import torch
import torch.nn as nn
# 定义 MLP 网络
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
# 创建 MLP 模型实例
model = MLP()
# 加载模型状态
checkpoint = torch.load('model_epoch25.pth')
model.load_state_dict(checkpoint['state_dict'])
# 将模型设置为评估模式
model.eval()
with torch.no_grad():
# 进行测试
# 这里省略测试代码
在使用 PyTorch 模型进行测试或评估时,通常不需要计算梯度,因为我们不会更新模型的权重。因此,我们可以使用 torch.no_grad() 上下文管理器,将计算图上下文中的梯度计算禁用掉,以提高计算效率。
在深度学习模型训练过程中,我们通常需要保存模型的状态,以便在需要时可以恢复模型并继续训练或进行推断。通常,我们需要保存以下状态:
模型的权重或参数;
优化器的状态,包括学习率、动量等;
当前的训练 epoch;
当前的训练损失等。
为了方便地保存这些状态,通常将它们保存在一个 Python 字典中。在 PyTorch 中,通常会使用以下代码创建这个字典:
state = {
'epoch': epoch, # 当前训练 epoch
'state_dict': model.state_dict(), # 模型的权重或参数
'optimizer': optimizer.state_dict(), # 优化器的状态
'loss': loss # 当前训练损失
}
在上面的代码中,epoch 表示当前训练 epoch,model.state_dict() 返回模型的权重或参数的字典表示,optimizer.state_dict() 返回优化器的状态字典表示,loss 表示当前训练损失。将这些状态保存在一个字典中,可以方便地将它们保存到磁盘中,并在需要时加载到模型中。