pytorch中保存网络模型的两种方式

一、只保存网络中的参数

保存:

torch.save(model.state_dict(), save_fp)

加载的时候需要先初始化一个模型,然后把文件中的参数恢复。

train_weights = torch.load(model_fp)
model = Model()
model.load_state_dict(model_weights)

这里load得到的是变量类型为OrderedDict(),也就是网络中的参数集合。

二、保存网络结构和参数

保存:

torch.save(model, save_fp)

加载:

model = torch.load(model_fp)

这里load的到的一个对象,类型是

你可能感兴趣的:(pytorch)