PyTorch
模型存储学习到的参数在内部状态字典中,称为 state_dict
, 他们的持久化通过 torch.save
方法。
model = models.shufflenet_v2_x0_5(pretrained=True)
torch.save(model, "../../data/ShuffleNetV2_X0.5.pth")
如果要加载模型的话,首先需要实例化一个同类型的模型对象,然后用 load_state_dict() 方法加载参数。
model = models.shufflenet_v2_x0_5()
model.load_state_dict(torch.load("../../data/ShuffleNetV2_X0.5.pth"))
model.eval()
Output exceeds the size limit. Open the full output data in a text editor
ShuffleNetV2(
(conv1): Sequential(
(0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(stage2): Sequential(
(0): InvertedResidual(
(branch1): Sequential(
(0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
(4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
...
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(fc): Linear(in_features=1024, out_features=1000, bias=True)
)
Saving and Loading Models with Shapes
当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能想要保存类的结构以及模型,在这种情况下,我们可以将 model (而不是 model.state_dict() ) 传递给保存函数:
torch.save(model, "../../data/ShuffleNetV2_X0.5_eval2.pth")
加载模型如这样:
model = torch.load("../../data/ShuffleNetV2_X0.5_eval2.pth")
print(model)
这种方法在序列化模型时使用 Python pickle 模块,因此它依赖于加载模型时可用的实际类定义。
Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。