PyTorch: saving a model and loading it to continue training

1. Defining the network and saving the model parameters

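This section defines the network and its basic hyperparameters, then trains the model and saves it.

The model is saved with torch.save().

The save_dict dictionary can hold the epoch, the model, the optimizer, the scheduler, the loss, and any other values needed later.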

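import torch
from torch import optim
from torch.optim import lr_scheduler

# VisionTransformer is the author's model class; its definition is not shown here.
my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()

# cuda was not defined in the original; assumed here to be a GPU-availability flag.
cuda = torch.cuda.is_available()
if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

# Make sure every parameter is trainable.
for p in my_net.parameters():
    p.requires_grad = True
bestacc = 0.0
savepth = 'mySavepthPath'
for epoch in range(n_epoch):
    my_net.train()
    ....  # training and evaluation steps omitted in the original; they must set acc
    if acc > bestacc:
        bestacc = acc  # remember the best accuracy so only improvements are saved
        # Store state_dicts rather than whole objects.
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(save_dict, savepth + '.pth')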
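
As noted above, the checkpoint dictionary can carry more than the model and optimizer. The following is a minimal sketch of a fuller checkpoint that also stores the scheduler state and the best accuracy. It assumes a small nn.Linear layer as a stand-in for VisionTransformer; the helper name save_checkpoint, the 'scheduler' and 'best_acc' keys, and the file name demo_checkpoint.pth are hypothetical and not part of the original code.

```python
import os
import tempfile

import torch
from torch import nn, optim
from torch.optim import lr_scheduler


def save_checkpoint(path, epoch, model, optimizer, scheduler, best_acc):
    """Write everything needed to resume training into one file."""
    torch.save(
        {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_acc': best_acc,
        },
        path,
    )


# Stand-in model (assumption): a small linear layer instead of VisionTransformer.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=0.001 / 100)

# Save into a temporary directory so the sketch leaves no files behind in the project.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'demo_checkpoint.pth')
save_checkpoint(ckpt_path, epoch=0, model=model, optimizer=optimizer,
                scheduler=scheduler, best_acc=0.5)

# Read the file back on the CPU to inspect what was stored.
saved = torch.load(ckpt_path, map_location='cpu')
saved_keys = sorted(saved.keys())
print(saved_keys)
```

Storing the scheduler state lets a resumed run continue the cosine learning-rate curve from where it stopped, and storing the best accuracy avoids overwriting a good checkpoint with a worse one after a restart.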

2. Loading the model to continue training

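The checkpoint is loaded with torch.load; the complete code follows.

Note the order: first define the model and the optimizer and move the model to the GPU, and only then load the checkpoint.
Otherwise optimizer.step() fails with the error below. optimizer.load_state_dict() places the restored state on whatever device the parameters occupy at load time, so state loaded while the model is still on the CPU stays on the CPU after the model is moved.
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!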
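import torch
from torch import optim
from torch.optim import lr_scheduler

my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()

# cuda was not defined in the original; assumed here to be a GPU-availability flag.
cuda = torch.cuda.is_available()
if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

Resume = True
start_epoch = -1  # -1 means "no checkpoint loaded, start from scratch"
if Resume:
    path_checkpoint = 'mySavepthPath.pth'
    # Map the checkpoint onto the device the model lives on (falls back to CPU without a GPU).
    checkpoint = torch.load(path_checkpoint, map_location=torch.device('cuda' if cuda else 'cpu'))
    my_net.load_state_dict(checkpoint['model'])
    # The model is already on its final device, so the optimizer state lands there too.
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']  # the epoch at which the checkpoint was saved
    print("start_epoch:", start_epoch)
    print('-----------------------------')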


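# Make sure every parameter is trainable.
for p in my_net.parameters():
    p.requires_grad = True

# bestacc restarts at 0.0 because the checkpoint above does not store it.
bestacc = 0.0
savepth = 'mySavepthPath'

new_start = 0 if start_epoch == -1 else start_epoch
# Fresh run: epochs 0 .. n_epoch - 1. Resumed run: continue from the epoch after the saved one.
# To stop at the originally planned total instead, use range(start_epoch + 1, n_epoch).
for epoch in range(start_epoch + 1, new_start + n_epoch):
    my_net.train()
    ....  # training and evaluation steps omitted in the original; they must set acc
    if acc > bestacc:
        bestacc = acc  # remember the best accuracy so only improvements are saved
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(save_dict, savepth + '.pth')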
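
The listing above restores the model and the optimizer but not the scheduler, so the learning-rate schedule would restart after a resume. The sketch below shows one way the whole round trip might look when the scheduler state is checkpointed as well. It is not the author's code: nn.Linear stands in for VisionTransformer, the data is random, and the helpers build and train_one_epoch, the 'scheduler' key, and the file name demo_resume.pth are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn, optim
from torch.optim import lr_scheduler


def build(n_epoch, device, lr=0.001):
    """Create a stand-in model on `device`, plus its optimizer and scheduler."""
    model = nn.Linear(4, 2).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epoch, eta_min=lr / 100)
    return model, optimizer, scheduler


def train_one_epoch(model, optimizer, scheduler, device):
    """One dummy optimisation step on random data, then one scheduler step."""
    x = torch.randn(8, 4, device=device)
    y = torch.randint(0, 2, (8,), device=device)
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()


n_epoch = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt_path = os.path.join(tempfile.mkdtemp(), 'demo_resume.pth')

# First run: train epochs 0 and 1, then checkpoint.
model, optimizer, scheduler = build(n_epoch, device)
for epoch in range(2):
    train_one_epoch(model, optimizer, scheduler, device)
torch.save(
    {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    },
    ckpt_path,
)
lr_at_save = optimizer.param_groups[0]['lr']

# Second run: rebuild everything on the target device first, then load the states.
model2, optimizer2, scheduler2 = build(n_epoch, device)
checkpoint = torch.load(ckpt_path, map_location=device)
model2.load_state_dict(checkpoint['model'])
optimizer2.load_state_dict(checkpoint['optimizer'])
scheduler2.load_state_dict(checkpoint['scheduler'])
start_epoch = checkpoint['epoch']
lr_after_load = optimizer2.param_groups[0]['lr']

# Continue with the epoch after the saved one and stop at the planned total.
resumed_epochs = []
for epoch in range(start_epoch + 1, n_epoch):
    train_one_epoch(model2, optimizer2, scheduler2, device)
    resumed_epochs.append(epoch)
print(start_epoch, resumed_epochs)
```

Here the resumed loop ends at n_epoch, so the total number of epochs matches the T_max given to CosineAnnealingLR. Whether to stop there or to train a further n_epoch epochs, as the listing above does, depends on the intent of the run.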
