pytorch - 模型保存及加载

一、pytorch 模型保存及加载

经常看到(.pt,.pth,.pkl)的pytorch模型文件,并不是格式上不同,只是后缀不同;
torch.save函数保存模型文件时,因人而异;
重点在于保存模型的方式不同,需要注意

1.1、只保存模型参数,不保存模型结构

保存:
    # 模型权重参数,不保存模型结构,速度快,占空间少
    torch.save(model.state_dict(), "mymodel.pth")    
调用:
    # 这里需要重新模型结构,My_model
    model = My_model(*args, **kwargs)     
    # 这里根据模型结构,调用存储的模型参数
    model.load_state_dict(torch.load(mymodel.pth))  
    model.eval()

1.2、保存整个模型,包括模型结构和模型参数

保存:
    # 保存整个model的状态
    torch.save(model, mymodel.pth)
调用:
    # 这里已经不需要重构模型结构了,直接load就可以
    model=torch.load(mymodel.pth)
    model.eval()

1.3、保存更多信息,如优化器参数

1)保存信息至字典,获取时通过字典获取

保存:
    torch.save({'epoch': epochID + 1,
                'state_dict':model.state_dict(),
                'best_loss': lossMIN,
                'optimizer': optimizer.state_dict(),
                'alpha': loss.alpha,
                'gamma': loss.gamma},
                checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f"% lossMIN) + '.pth.tar')
调用:
    def load_checkpoint(model, checkpoint_PATH, optimizer):
        if checkpoint != None:
            model_CKPT = torch.load(checkpoint_PATH)
            model.load_state_dict(model_CKPT['state_dict'])
            print('loading checkpoint!')
            optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer 

2)如若修改了网络结构,如增删操作,则需要过滤这些参数,加载方式略有不同

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # 过滤操作
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # 打印出来,更新了多少的参数
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
        # 如果不需要更新优化器那么设置为false
        if loadOptimizer == True:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded! optimizer')
        else:
            print('not loaded optimizer')
    else:
        print('No checkpoint is included')
    return model, optimizer

1.4、冻结部分参数,训练另一部分参数(special)

需求较少,后续添加

你可能感兴趣的:(模型训练,pytorch,python,深度学习)