Saving and Loading Models in PyTorch

torch.save: Saves a serialized object to disk. This function uses Python's pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
torch.load: Uses pickle's unpickling facilities to deserialize pickled object files into memory. Through its map_location argument it also lets you choose the device to load the data onto (see "Saving & Loading Model Across Devices" in the PyTorch tutorial).
torch.nn.Module.load_state_dict: Loads a model's parameter dictionary using a deserialized state_dict, i.e. a Python dictionary that maps each layer's parameter and buffer names to their tensors (see "What is a state_dict?" in the PyTorch tutorial).
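As a quick illustration of how torch.save and torch.load fit together, here is a minimal round trip, assuming a PyTorch version recent enough to have the weights_only argument. An in-memory io.BytesIO buffer stands in for a file on disk (both functions accept file-like objects as well as paths); the dictionary contents are made up for the example.

```python
import io
import torch

# torch.save pickles an arbitrary object: here, a dict holding a tensor and a plain int.
obj = {"weights": torch.arange(6, dtype=torch.float32).reshape(2, 3), "step": 10}

buffer = io.BytesIO()  # an in-memory file standing in for a path on disk
torch.save(obj, buffer)

buffer.seek(0)  # rewind before reading
restored = torch.load(buffer, map_location="cpu")  # map_location picks the target device

print(restored["step"])                                  # 10
print(torch.equal(restored["weights"], obj["weights"]))  # True
```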

Saving and loading weights (state_dict)

Save:

torch.save(model.state_dict(), PATH)

Load:

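model = TheModelClass(*args, **kwargs)  # the architecture must be rebuilt in code first
model.load_state_dict(torch.load(PATH))
model.eval()  # switch dropout and batch-norm layers to inference mode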
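A minimal end-to-end sketch of this pattern. TinyNet is a hypothetical stand-in for TheModelClass and the file name is made up; the check at the end confirms that a freshly built instance produces the same outputs once the saved state_dict is loaded into it.

```python
import os
import tempfile
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in for TheModelClass.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

torch.manual_seed(0)
model = TinyNet()
print(list(model.state_dict().keys()))  # ['fc.weight', 'fc.bias']

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tinynet_weights.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)            # saves parameters and buffers only

    restored = TinyNet()                            # rebuild the architecture first
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```

Because only tensors are stored, the file does not depend on where TinyNet is defined; the trade-off is that the loading code must construct the model itself.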

Saving/loading the entire model

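Save:
torch.save(model, PATH)
Load:
# the model's class must be importable here; PyTorch 2.6+ also needs weights_only=False
model = torch.load(PATH, weights_only=False)
model.eval()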
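The sketch below pickles a whole model, assuming PyTorch 1.13 or later (where torch.load accepts weights_only). It uses nn.Sequential rather than a custom class on purpose: pickle stores only a reference to the model's class, so that class must be importable under the same module path wherever the file is loaded, and built-in modules always satisfy this.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

buffer = io.BytesIO()      # stands in for the file at PATH
torch.save(model, buffer)  # pickles the module object itself, not just its tensors
buffer.seek(0)

# PyTorch 2.6+ defaults to weights_only=True, which refuses arbitrary pickled
# objects such as a whole nn.Module; only disable it for files you trust,
# because unpickling can execute code.
loaded = torch.load(buffer, weights_only=False)
loaded.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print(type(loaded).__name__, same)  # Sequential True
```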

Saving and loading a checkpoint (hyperparameters and training state)

Save:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            # ... any other items you want to store
            }, PATH)
Load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()  # to resume training; call model.eval() instead for inference
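A hedged sketch of saving and resuming a checkpoint with a real optimizer. The one-layer model, the Adam optimizer, and the random data are illustrative assumptions; after one step Adam holds per-parameter moment estimates, so the final check shows that optimizer state, not just the model weights, survives the round trip.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# One training step so the optimizer has internal state (Adam moment estimates).
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

buffer = io.BytesIO()  # stands in for the file at PATH
torch.save({
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),  # a plain float rather than a graph-attached tensor
}, buffer)
buffer.seek(0)

new_model = nn.Linear(3, 1)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-2)
checkpoint = torch.load(buffer)
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch, saved_loss = checkpoint["epoch"], checkpoint["loss"]
new_model.train()  # resume training

print(epoch, torch.equal(new_model.weight, model.weight))  # 1 True
```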

Loading a model for transfer learning

Partially loading a model, or loading part of a model
Method 1: load the parameters first, then modify the model

model = TheModelBClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH), strict=False)

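strict=False relaxes key matching: keys in the checkpoint that the model lacks, and model keys the checkpoint lacks, are ignored instead of raising an error. Shapes must still agree for keys present on both sides, which is why the fc layer is replaced only after loading.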

inchannel = model.fc.in_features  # number of input features of the network's fc layer
model.fc = nn.Linear(inchannel, 5)  # replace the fc layer with a new one (5 outputs here)
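A small sketch of Method 1, using a hypothetical Net class whose checkpoint carries an extra key (aux.weight, also hypothetical) that the target model does not have. load_state_dict returns the keys it ignored, which is a handy way to verify what strict=False actually skipped.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical model: a small feature extractor plus a classification head.
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

pretrained = Net(num_classes=10)
state = pretrained.state_dict()
state["aux.weight"] = torch.zeros(1)  # hypothetical extra key the target model lacks

model = Net(num_classes=10)  # same head size as the checkpoint, so shapes match
result = model.load_state_dict(state, strict=False)
print(result.unexpected_keys)  # ['aux.weight']

inchannel = model.fc.in_features    # 32
model.fc = nn.Linear(inchannel, 5)  # new 5-class head, freshly initialized

print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 5])
```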

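Method 2:
state_dict = torch.load(PATH)  # first load the weights into memory
Then delete the entries for the layers you want to change,
and finally load what remains into the already-modified model (with strict=False, since the deleted keys are now missing).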
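Method 2 can be sketched like this, again with a hypothetical Net class and an in-memory buffer standing in for the file at PATH. Keys belonging to the fc layer are dropped before loading, so the new 5-class head keeps its fresh initialization, and strict=False lets the now-missing fc keys pass.

```python
import io
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical model: shared feature extractor plus a task-specific fc head.
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

# Pretend this buffer is the checkpoint file at PATH (trained with 10 classes).
buffer = io.BytesIO()
torch.save(Net(num_classes=10).state_dict(), buffer)
buffer.seek(0)

state = torch.load(buffer)                     # 1. load the weights into memory
for key in [k for k in state if k.startswith("fc.")]:
    del state[key]                             # 2. drop the layer being replaced

model = Net(num_classes=5)                     # 3. the already-modified model
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)                     # ['fc.weight', 'fc.bias']
```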

Either approach works.

Saving CPU work: load the weights directly onto the GPU

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
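A variant of the snippet above that falls back to the CPU when no GPU is present, so it runs anywhere; the small nn.Linear model and the in-memory buffer are illustrative assumptions. With map_location, the tensors are created directly on the target device as the file is read.

```python
import io
import torch
import torch.nn as nn

# Fall back to the CPU so the sketch also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

buffer = io.BytesIO()
torch.save(nn.Linear(4, 2).state_dict(), buffer)  # stands in for the file at PATH
buffer.seek(0)

state = torch.load(buffer, map_location=device)  # tensors land directly on `device`
print({t.device.type for t in state.values()})

model = nn.Linear(4, 2)
model.load_state_dict(state)
model.to(device)  # the model's own parameters still have to be moved
```

load_state_dict copies the loaded values into the model's existing parameters rather than replacing them, which is why the final model.to(device) call is still needed.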
