基于 GitHub 上的 xxradon/PytorchToCaffe 源码，将 example\resnet_pytorch_2_caffe.py 修改如下：
import os
import sys
sys.path.insert(0, '.')
import torch
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe
"""
resnet models in pytorch format can be downloaded from
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
"""
def show_usage(cmd):
    """Print the command-line usage of this conversion script.

    Args:
        cmd: the program name (normally ``sys.argv[0]``) shown in the
            usage line.
    """
    print("Usage:")
    # The original usage line printed only the program name and never
    # mentioned the two required positional arguments.
    print(" ", cmd, "<model_name> <checkpoint.pth>")
    print("  model_name : one of resnet18, resnet34, resnet50, resnet101, resnet152")
def main(cmd, argv):
    """Convert a torchvision ResNet checkpoint into Caffe prototxt/caffemodel.

    Args:
        cmd: program name, used only when printing the usage message.
        argv: arguments after the program name. ``argv[0]`` is the model
            name (e.g. ``"resnet152"``) and ``argv[1]`` is the path to the
            PyTorch checkpoint that is passed to ``load_state_dict``.

    The outputs are written to ``<input path without extension>.prototxt``
    and ``<input path without extension>.caffemodel``.

    Exits the process with status 1 when too few arguments are given, 2 for
    an unknown model name, and 3 when the input file does not exist.
    """
    if len(argv) < 2:
        print("Error! Parameter is not enough.")
        show_usage(cmd)
        # sys.exit, not the site-injected exit(), which is meant for the
        # interactive shell and is absent under `python -S`.
        sys.exit(1)

    model_name = argv[0]
    input_file = argv[1]
    file_name = os.path.splitext(input_file)[0]
    prototxt_file = '{}.prototxt'.format(file_name)
    caffemodel_file = '{}.caffemodel'.format(file_name)

    print(" model : ", model_name)
    print(" input : ", input_file)
    print(" output : ", prototxt_file)
    print(" ", caffemodel_file)

    # One table of supported architectures instead of a case per name.
    builders = {
        "resnet18": resnet.resnet18,
        "resnet34": resnet.resnet34,
        "resnet50": resnet.resnet50,
        "resnet101": resnet.resnet101,
        "resnet152": resnet.resnet152,
    }
    builder = builders.get(model_name)
    if builder is None:
        print("Error! Unknown model name : ", model_name)
        show_usage(cmd)
        sys.exit(2)

    if not os.path.isfile(input_file):
        print("Error! Cannot find input file : ", input_file)
        show_usage(cmd)
        sys.exit(3)

    # Build the network only after the inputs are validated.
    resnet_x = builder()
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
    # machines. NOTE(review): torch.load unpickles the file, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(input_file, map_location="cpu")
    resnet_x.load_state_dict(checkpoint)
    resnet_x.eval()

    # Dummy all-ones tensor fed through the network for the conversion;
    # assumes the standard 1x3x224x224 input these ResNets expect.
    dummy_input = torch.ones([1, 3, 224, 224])
    pytorch_to_caffe.trans_net(resnet_x, dummy_input, model_name)
    pytorch_to_caffe.save_prototxt(prototxt_file)
    pytorch_to_caffe.save_caffemodel(caffemodel_file)
if __name__ == "__main__":
    # Hand the script name and the remaining positional arguments to main().
    script_path = sys.argv[0]
    main(script_path, sys.argv[1:])
脚本依赖 PyTorch 和 torchvision（用于加载 ResNet 模型定义），先安装：
pip install torch torchvision
运行时如果遇到 protobuf 版本过高导致的报错，将其降级：
pip install -U protobuf==3.20
下载对应的 ResNet 模型文件（.pth）后，执行脚本：
python example\resnet_pytorch_2_caffe.py resnet152 resnet152-b121ed2d.pth