timm 和 torchvision 中的 resnet50

从 timm 和 torchvision 分别加载 resnet50 预训练模型,

import torch
def export_onnx(model_saved, onnx_save_name, input_name='img', output_name='logits'):
    dummy_input = torch.randn(1, 3, 224, 224)
    dynamic_axes = dict()
    dynamic_axes[input_name] = {0:"batch_size"}
    dynamic_axes[output_name] = {0:"batch_size"}
    torch.onnx.export(model_saved, dummy_input, onnx_save_name,
        input_names=[input_name], output_names=[output_name],
        export_params=True, verbose=False, opset_version=12,
        dynamic_axes=dynamic_axes)
    
if __name__ == '__main__':
    import torchvision
    net = torchvision.models.resnet50(pretrained=True)
    export_onnx(net, './resnet50_torchvision.onnx')
    
    import timm
    net = timm.create_model('resnet50', pretrained=True)
    export_onnx(net, './resnet50_timm.onnx')

从 onnx 看,权重是一样的。

timm 和 torchvision 中的 resnet50_第1张图片

 

你可能感兴趣的:(模式识别,深度学习,resnet50)