将 PyTorch 格式模型转换为 Caffe 格式模型（pth → caffemodel）

基于 GitHub 上 xxradon/PytorchToCaffe 的源码，将 example\resnet_pytorch_2_caffe.py 修改如下：

import os
import sys
sys.path.insert(0, '.')

import torch
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe


"""
    Pretrained ResNet models in PyTorch (.pth) format can be downloaded from:
        'resnet18':  'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34':  'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50':  'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',

"""

def show_usage(cmd):
    """Print command-line usage for this conversion script.

    Args:
        cmd: Program name (typically ``sys.argv[0]``) shown in the usage line.
    """
    print("Usage:")
    print("    python", cmd, "<model_name> <input.pth>")
    print("      model_name : resnet18 | resnet34 | resnet50 | resnet101 | resnet152")
    print("      input.pth  : PyTorch state_dict file; results are written next to it")
    print("                   as <input>.prototxt and <input>.caffemodel")
    
def main(cmd, argv):
    """Convert a torchvision ResNet checkpoint into Caffe prototxt/caffemodel files.

    Args:
        cmd: Program name, forwarded to ``show_usage`` on errors.
        argv: Command-line arguments. ``argv[0]`` is the model name
            (resnet18, resnet34, resnet50, resnet101 or resnet152) and
            ``argv[1]`` the path to a ``.pth`` state_dict. Extra arguments
            are ignored.

    The outputs ``<input>.prototxt`` and ``<input>.caffemodel`` are written
    using the input path with its extension stripped.

    Exits (``SystemExit``) with status 1 if fewer than two arguments are given,
    2 for an unknown model name, and 3 if the input file does not exist.
    """
    if len(argv) < 2:
        print("Error! Parameter is not enough.")
        show_usage(cmd)
        sys.exit(1)

    model_name = argv[0]
    input_file = argv[1]

    # Output files share the input's path, minus its extension.
    file_name = os.path.splitext(input_file)[0]

    print(" model  : ", model_name)
    print(" input  : ", input_file)
    print(" output : ", '{}.prototxt'.format(file_name))
    print("          ", '{}.caffemodel'.format(file_name))

    # Look up the constructor now, but only build the network once the
    # input file is known to exist.
    constructors = {
        "resnet18": resnet.resnet18,
        "resnet34": resnet.resnet34,
        "resnet50": resnet.resnet50,
        "resnet101": resnet.resnet101,
        "resnet152": resnet.resnet152,
    }
    build_model = constructors.get(model_name)
    if build_model is None:
        print("Error! Unknown model name : ", model_name)
        show_usage(cmd)
        sys.exit(2)

    if not os.path.isfile(input_file):
        print("Error! Cannot find input file : ", input_file)
        show_usage(cmd)
        sys.exit(3)

    # Example input handed to pytorch_to_caffe.trans_net:
    # batch of 1, 3 channels, 224x224 (renamed from `input` to avoid shadowing the builtin).
    dummy_input = torch.ones([1, 3, 224, 224])

    resnet_x = build_model()
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    # NOTE(security): torch.load unpickles the file -- only convert checkpoints you trust.
    checkpoint = torch.load(input_file, map_location="cpu")
    resnet_x.load_state_dict(checkpoint)
    # Switch BatchNorm/Dropout to inference behaviour before conversion.
    resnet_x.eval()

    pytorch_to_caffe.trans_net(resnet_x, dummy_input, model_name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(file_name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(file_name))
    
    
if __name__ == "__main__":
    # Split argv into the program name and the user-supplied arguments.
    program, arguments = sys.argv[0], sys.argv[1:]
    main(program, arguments)

脚本依赖 PyTorch 和 torchvision（脚本中导入了 torchvision.models），先安装：

pip install torch torchvision

运行时如果遇到 protobuf 版本过高导致的报错，可将 protobuf 降级：

pip install -U protobuf==3.20 

下载对应的 ResNet 模型文件（.pth，下载地址见脚本开头的字符串注释）后，执行脚本：

python example\resnet_pytorch_2_caffe.py  resnet152  resnet152-b121ed2d.pth

你可能感兴趣的:(pytorch,caffe,python)