Pytorch-数据保存载入

pytorch中保存模型相关的函数有3个:

torch.save:利用python的pickle模块实现序列化并保存序列化后的object

torch.load:利用pickle将保存的object反序列化

torch.nn.Module.load_state_dict:通过反序列化得到的state_dict读取保存的训练参数

有两种方法保存模型:

1. torch.save(model, path) # 直接保存整个模型
2. torch.save(model.state_dict(), path) # 保存模型的参数

相应地有两种方法加载保存的模型:

model = torch.load(path) # 直接加载模型
model = Model()                         # 先初始化一个模型
model.load_state_dict(torch.load(path)) # 再加载模型参数

  state_dict

Pytorch-数据保存载入_第1张图片

 net._modules.items()

Pytorch-数据保存载入_第2张图片

你可能感兴趣的:(Torch,pytorch,python,深度学习)