模型设置断点重新训练

模型在训练过程中因为某些原因中断,想要在上次训练的基础之上接着训练,则需要提前在训练过程中保存相应的状态,具体如下:

首先要在训练过程中保存模型参数、当前的epoch数以及相应的优化器参数,因为在训练过程中相应的优化器相关参数会不断进行更新,因此不能只保存模型参数,当前训练期数epoch以及相应的优化器。

# 下述代码加在训练的周期迭代过程最后即可
# Save models checkpoints
    state_G_A2B = {
        'epoch': epoch,  # 当前训练期数
        'net': netG_A2B.state_dict(),  # 网络参数
        'optimizer': optimizer_G.state_dict(),  # 优化器相关参数
    } 
    torch.save(state_G_A2B, 'output/netG_A2B_%d.pth'%{epoch})

(1)设置参数

parser.add_argument('--resume', type=bool, default=True, help='whether to resume training')

(2)判断resume参数是否为True并进行相应的初始化

注意,在对优化器、网络参数以及epoch进行更新赋值时,再赋值前必须先有其相关定义。

if not opt.resume:
    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

else:
    # Load state dicts
    # G_A2B保存的权重(包含epoch、权重以及优化器)
    checkpoint_G_A2B=torch.load('')
    # 加载上次训练结束时所在的epoch
    opt.epoch=checkpoint_G_A2B['epoch']
    netG_A2B.load_state_dict(checkpoint_G_A2B['net'])
    optimizer_G.load_state_dict(checkpoint_G_A2B['optimizer'],strict=False)
    # G_B2A保存的权重(包含epoch、权重以及优化器)
    checkpoint_G_B2A = torch.load('')
    # D_A保存的权重(包含epoch、权重以及优化器)
    checkpoint_D_A = torch.load('')
    netD_A.load_state_dict(checkpoint_D_A['net'])
    optimizer_D_A.load_state_dict(checkpoint_D_A['optimizer'],strict=False)
    # D_B保存的权重(包含epoch、权重以及优化器)
    checkpoint_D_B = torch.load('')
    netD_B.load_state_dict(checkpoint_D_B['net'])
    optimizer_D_B.load_state_dict(checkpoint_D_B['optimizer'],strict=False)

你可能感兴趣的:(python,算法,开发语言)