pytorch保存模型方法

Pytorch 有两种保存模型的方式,都是通过调用pickle序列化方法实现的。

第一种方法只保存模型参数。第二种方法保存完整模型。推荐使用第一种,第二种方法可能在切换设备和目录的时候出现各种问题。

1.保存模型参数方法:

print(model.state_dict().keys())                                # 输出模型参数名称

# 保存模型参数到路径"./data/model_parameter.pkl"
torch.save(model.state_dict(), "./data/model_parameter.pkl")
new_model = Model()                                                    # 调用模型Model
new_model.load_state_dict(torch.load("./data/model_parameter.pkl"))    # 加载模型参数     
new_model.forward(input)                                               # 进行使用

2.保存完整模型(不推荐)

torch.save(model, './data/model.pkl')        # 保存整个模型
new_model = torch.load('./data/model.pkl')   # 加载模型

3.Transfomers库预训练模型的加载

# 使用transformers预训练后进行保存
model.save_pretrained(model_path)                              
tokenizer.save_pretrained(tokenizer_path)

# 预训练模型使用 `from_pretrained()` 重新加载
model.from_pretrained(model_path)                              
tokenizer.from_pretrained(tokenizer_path)

你可能感兴趣的:(pytorch,pytorch)