Pytorch DataParallel多卡训练模型导出onnx模型

Pytorch模型转换到onnx模型代码如下:

import torch
import torch.nn as nn
import torch.onnx
import onnx
import os
from QualityNet import QualityNet

if __name__ == '__main__':
    # First (failing) attempt from the article: load a checkpoint that was
    # saved from an nn.DataParallel-wrapped model straight into a plain
    # QualityNet, then export it to ONNX. This is kept as-is to reproduce
    # the error discussed in the text; see the corrected script further down.

    # Restrict the process to physical GPU 1 (enumerated in PCI bus order);
    # inside this process it is then addressed as 'cuda:0'.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_path = './models/pytorch/face_quality.pth'
    
    # NOTE(review): no map_location, so tensors are restored onto the device
    # they were saved from; this can fail on a CPU-only host or when that GPU
    # index is not visible.
    state_dict = torch.load(model_path)
    
    model = QualityNet()
    model = nn.DataParallel(model)
    # NOTE(review): this line rebinds `model` to a fresh, unwrapped QualityNet,
    # discarding the DataParallel wrapper (and the instance) created above.
    # The checkpoint's keys carry a "module." prefix, so the strict
    # load_state_dict below raises a missing/unexpected-key error.
    model = QualityNet().to(device)
    model.load_state_dict(state_dict)
    model.eval()

    onnx_path = './models/onnx/face_quality.onnx'
    # Dummy tensor of shape (1, 3, 128, 128) used to trace the graph;
    # presumably QualityNet takes 3-channel 128x128 inputs — confirm.
    dummy_input = torch.ones(1, 3, 128, 128,)
    dummy_input = dummy_input.to(device)
    input_names = ["input"]
    output_names = ["output"]
    # export onnx model
    torch.onnx.export(model, dummy_input, onnx_path, verbose=False, opset_version=9, input_names=input_names, output_names=output_names)
    # load onnx model
    onnx_model = onnx.load(onnx_path)
    # check onnx model (structural validation of the exported graph)
    onnx.checker.check_model(onnx_model)
    

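运行后会在 `model.load_state_dict(state_dict)` 处报错:模型文件中的键与模型的键不匹配(PyTorch 会提示 Missing key(s) / Unexpected key(s))。注意上面的代码中,第二次 `QualityNet().to(device)` 覆盖了 `nn.DataParallel` 包装后的模型,因此待加载的模型是未包装的普通模型。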

Pytorch采用DataParallel进行多卡训练得到的模型文件,无法直接加载到普通(未用DataParallel包装)的模型中再转换为onnx模型。原因是DataParallel把原模型保存在其 `module` 属性中,因此多卡训练保存的模型文件中,每个键(key)前面都会多一个 "module." 前缀:

(图1:模型文件中的键名示例,均以 "module." 开头)

解决方法很简单:只需去掉键名中多余的 "module." 前缀即可。具体做法是新建一个OrderedDict,把修改后的键和原来的值写入其中,再将它载入模型。修改后的代码如下:

import torch
import torch.nn as nn
from collections import OrderedDict
import torch.onnx
import onnx
import os
from QualityNet import QualityNet

def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    ``nn.DataParallel`` keeps the wrapped network in its ``module`` attribute,
    so a ``state_dict`` saved from the wrapper has every key prefixed with
    ``"module."``, while an unwrapped model expects the keys without it.
    Keys that do not start with *prefix* are copied unchanged, so this is also
    safe to call on a checkpoint saved from an unwrapped model.

    Args:
        state_dict: Mapping of parameter/buffer names to values, e.g. the
            object returned by ``torch.load`` on a saved ``state_dict``.
        prefix: Prefix to strip. Defaults to ``"module."``.

    Returns:
        A new ``OrderedDict`` with the same values, in the same order.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip only when the prefix is really there: blindly slicing k[7:]
        # would mangle keys of a checkpoint saved without DataParallel.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict


if __name__ == '__main__':

    # Restrict the process to physical GPU 1 (enumerated in PCI bus order);
    # inside this process it is then addressed as 'cuda:0'.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_path = './models/pytorch/face_quality.pth'

    # map_location remaps tensors onto the device we actually use, so a
    # checkpoint saved on another GPU index (or on GPU at all) still loads
    # when only one GPU, or none, is visible.
    state_dict = torch.load(model_path, map_location=device)
    # The checkpoint was saved from an nn.DataParallel model: drop the
    # "module." prefix so the keys match a plain QualityNet.
    new_state_dict = strip_module_prefix(state_dict)

    model = QualityNet().to(device)
    model.load_state_dict(new_state_dict)
    model.eval()

    onnx_path = './models/onnx/face_quality.onnx'
    # Make sure the output directory exists before the exporter writes the file.
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
    # Dummy tensor of shape (1, 3, 128, 128) used to trace the graph;
    # presumably QualityNet takes 3-channel 128x128 inputs — confirm.
    dummy_input = torch.ones(1, 3, 128, 128, device=device)
    input_names = ["input"]
    output_names = ["output"]
    # export onnx model
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        verbose=False,
        opset_version=9,
        input_names=input_names,
        output_names=output_names,
    )
    # load onnx model
    onnx_model = onnx.load(onnx_path)
    # check onnx model (structural validation of the exported graph)
    onnx.checker.check_model(onnx_model)

 

 

你可能感兴趣的:(深度学习,pytorch,onnx)