pytorch保存模型pth_pytorch模型文件pth详解

1.pth文件中保存的是什么

import torch

state_dict = torch.load("resnet18.pth")

print(type(state_dict))

---------------

如上打印输出所示,pth文件通过有序字典来保持模型参数。有序字典与常规字典一样,但是在排序操作方面有一些额外的功能。常规的dict是无序的,OrderedDict能够比dict更好地处理频繁的重新排序操作。

OrderedDict有一个方法

import torch

state_dict = torch.load("resnet18.pth")

print(type(state_dict))

for i in state_dict:

print(i)

print(type(state_dict[i]))

print("aa:",state_dict[i].data.size())

print("bb:",state_dict[i].requires_grad)

break

------------------------------

conv1.weight

aa: torch.Size([64, 3, 7, 7])

bb: True

如上打印所示,有序字典state_dict中每个元素都是Parameter参数,该参数是一种特殊的张量,包含data和requires_grad两个方法。其中data字段保存的是模型参数,requires_grad字段表示当前参数是否需要进行反向传播。

2.torch.save()

先建立一个字典,保存三个参数:调用torch.save(),即可保存对应的pth文件。需要注意的是若模型是由nn.Moudle类继承的模型,保存pth文件时,state_dict参数需要由model.state_dict指定。

state_dict = {‘net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}

torch.save(state_dict , dir)

--------------------------------

torch.save(model.state_dict,dir)

3.torch.load()

当你想恢复某一阶段的训练(或者进行测试)时,那么就可以读取之前保存的网络模型参数等。

checkpoint = torch.load(dir)

model.load_state_dict(checkpoint['net'])

optimizer.load_state_dict(checkpoint['optimizer'])

start_epoch = checkpoint['epoch'] + 1

你可能感兴趣的:(pytorch保存模型pth)