将pytorch的pth文件固化为pt文件

说明

我参考了一个开源的人像语义分割项目mobile_phone_human_matting,这个项目提供了预训练模型,我想要将该模型固化,然后转换格式后在嵌入式端使用。

该项目保存模型的代码如下:

lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model)
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
        }, lastest_out_path)

转换代码

上面代码保存了state_dict, 所以保存的文件中是不含模型结构的,固化时需要从代码构造网络结构。好在项目是完全开源,将原项目下的model目录拷贝过来就行。
另外不能忘记调用eval() 来固化参数。

完整的转换代码如下:

import torch
from model import segnet

ckptfile="./ckpt_lastest.pth"
savedfile="./human_seg.pt"

model = segnet.SegMattingNet()    

device = torch.device('cpu')
ckpt = torch.load(ckptfile, map_location=device )
model.load_state_dict(ckpt['state_dict'])

model.eval() #这一步会将参数固化,不能省。否则会报AssertionError('batchnorm with training is not support. Please set model.eval() before export.')

x = torch.rand(1,3,256,256)
ts = torch.jit.trace(model, x)
ts.save(savedfile)

参考资料

mobile_phone_human_matting

pytorch训练的.pth模型格式转换

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能)