PyTorch模型转换为TorchScript格式

最近入坑了PyTorch,在学习PyTorch Mobile的安卓部分。要想将训练好模型迁移到手机上使用,需要将模型转化为TorchScript,它是PyTorch模型(子类nn.Module)的中间表示,可以在高性能环境(例如C ++)中运行。

转换的方法有两种,一种是通过追踪转换另一中是通过注释转换,本文使用的是通过追踪转换的方法。


import torch
import torchvision
import torch.nn as nn

# 加载模型(根据自己模型修改结构和参数)
model_ft = torchvision.models.mobilenet_v2()
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, 2, bias=True)
model_ft.load_state_dict(torch.load('/home/well/0.94118mobilenet.pt'))
model_ft.eval()

# 给模型的forward()方法一个示例输入
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model_ft, example)
# 保存模型
traced_script_module.save("/home/well/model.pt")

print("Finished Transformation")

PyTorch模型转换为TorchScript格式_第1张图片
RuntimeError: Error(s) in loading state_dict for MobileNetV2:
Missing key(s) in state_dict: “features.0.0.weight”, … (这里表示在state_dict中找不到这些参数)
Unexpected key(s) in state_dict: “module.features.0.0.weight”, …(预料外的参数)

从上面的错误信息中可以发现:这话是参数不匹配的问题,我们传入模型的参数格式是module.features.XXX而模型需要的参数格式是features.XXX。

经过一番查找研究后发现:原因是在多GPU训练的时候,nn.DataParallel(model)对模型进行了包装,所以使用model.state_dict()保存模型,保存的参数格式是module.features.XXX;而我们在转换的过程中加载模型参数的格式的features.XXX,所以报错。


解决方法:在训练的时候使用model.module.state_dict() 代替原来的 model.state_dict() 保存模型,这样我们就可以将PyTorch模型转换为TorchScript格式。
PyTorch模型转换为TorchScript格式_第2张图片
得到的TorchScript模型:
PyTorch模型转换为TorchScript格式_第3张图片

你可能感兴趣的:(PyTorch,pytorch)