Python model saving: saving dictionary data, checkpoint + .pth file handling

You can skim the examples first; detailed explanations follow below.

The most important part of a checkpoint is model.state_dict(); all of the methods below revolve around it.

1. Saving as a dictionary

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)

Loading

checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
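
For reference, a minimal round-trip sketch that also saves the optimizer state (model, optimizer, and PATH here are placeholder names, not from the snippets above):

import torch
import torch.nn as nn
import torch.optim as optim

# placeholder model and optimizer; substitute your own
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
PATH = 'checkpoint.pth'

# save: bundle everything needed to resume into one dict
torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, PATH)

# load: construct the objects first, then restore their states
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
model.train()  # or model.eval() for inference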

2. An enhanced version of the above that also records whether the current checkpoint is the best result so far

# save function (args.name comes from argparse in the original training script)
import os
import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    directory = "../models/%s/" % (args.name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    # keep a separate copy of the best checkpoint so far
    if is_best:
        shutil.copyfile(filename, '../models/%s/' % (args.name) + 'model_best.pth.tar')

# usage (inside the training loop)
save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
}, is_best)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

3. Further enhancement (saving and loading are split into two separate functions, and both the network parameters and the optimizer parameters are saved)

Saving

import glob
import os
import torch

def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    # keep at most max_to_save checkpoints: delete the oldest one (by filename order)
    total_models = glob.glob(model_path + '*')
    if len(total_models) >= max_to_save:
        total_models.sort()
        os.remove(total_models[0])

    state_dict = {}
    state_dict["model_state_dict"] = model_state_dict
    state_dict["optimizer_state_dict"] = optimizer_state_dict

    torch.save(state_dict, model_path + '_' + str(epoch))
    print('model {} saved successfully!'.format(model_path + '_' + str(epoch)))

# usage
from check import loader, saver
saver(model.state_dict(), optimizer.state_dict(), model_save_path, epoch + 1, max_to_save=100)

Loading

def loader(model_path):
    state_dict = torch.load(model_path)
    model_state_dict = state_dict["model_state_dict"]
    optimizer_state_dict = state_dict["optimizer_state_dict"]
    return model_state_dict, optimizer_state_dict
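
A matching resume sketch (the model and optimizer objects must already be constructed before their states are restored; checkpoint_path is a placeholder name):

# restore states into freshly constructed model/optimizer objects
model_state_dict, optimizer_state_dict = loader(checkpoint_path)
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)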

Note:
PyTorch's other "checkpoint" (from torch.utils.checkpoint import checkpoint) serves a different purpose: saving memory during training. The forward computation of a model (or part of it) is wrapped as a checkpoint; its activations are not kept during the forward pass and are recomputed from the checkpoint during the backward pass.
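
A minimal sketch of this gradient checkpointing (the block and input shapes are made up for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# a block whose intermediate activations we do not want to keep in memory
block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
x = torch.randn(8, 512, requires_grad=True)

# activations inside `block` are recomputed during backward instead of being stored
y = checkpoint(block, x)
y.sum().backward()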



If you hit: RuntimeError: params/unet.pth is a zip archive (did you mean to use torch.jit.load()?)

This is caused by a PyTorch version mismatch: torch.save switched to a zip-based file format in PyTorch 1.6, and older PyTorch versions cannot read files saved in the new format.
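
If you can re-save the file with PyTorch >= 1.6, one workaround is to fall back to the legacy serialization format (unet_legacy.pth is a hypothetical output name); otherwise, upgrade the PyTorch installation that does the loading:

import torch

# re-save with the legacy (non-zip) format so that older PyTorch versions can read it
state = torch.load('params/unet.pth')
torch.save(state, 'params/unet_legacy.pth', _use_new_zipfile_serialization=False)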
