模型转换:PyTorch 模型转 ONNX、ONNX 转 TensorFlow、TensorFlow 转 TFLite

文章目录

  • 软件版本:
  • PyTorch 模型转 ONNX
  • ONNX 模型转 TensorFlow
  • TensorFlow 模型转 TFLite

软件版本:

tensorflow 2.3.1
pytorch 1.6.0
onnxruntime 1.8.1
opencv-python (cv2) 4.5.3
onnx-tf (onnx_tf) 1.8.0
onnx 1.10.1

PyTorch 模型转 ONNX

import os
import random

import cv2
import numpy as np
import onnxruntime
import torch.onnx

# Seed every RNG in use so that repeated PyTorch runs give identical outputs.
def set_seed(seed=1):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: value used for every generator (default 1).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Per the PyTorch docs these are silently ignored when CUDA is unavailable,
    # so they are safe to call on CPU-only machines.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the current interpreter's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Deterministic cuDNN kernels and no autotuning: slower but reproducible.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
 
def get_img_batch(img_path, input_size=224, resize_ratio=0.875):
    """Read an image and turn it into a normalized NCHW float batch.

    Reproduces ``Resize(int(input_size / resize_ratio))`` (short side, aspect
    ratio kept) + ``CenterCrop(input_size)`` + per-channel normalization.
    The preprocessing must match the one used when the model was trained.

    Args:
        img_path: path of the image file, read with OpenCV.
        input_size: side length of the square center crop (default 224).
        resize_ratio: the short side is first resized to
            ``int(input_size / resize_ratio)`` (default 0.875 -> 256).
            Must be in (0, 1] so the crop fits inside the resized image.

    Returns:
        float32 tensor of shape (1, 3, input_size, input_size), RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
        ValueError: if ``resize_ratio`` is outside (0, 1].
    """
    if not 0 < resize_ratio <= 1:
        raise ValueError(f"resize_ratio must be in (0, 1], got {resize_ratio}")
    expand_size = int(input_size / resize_ratio)

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"cannot read image: {img_path}")
    # OpenCV loads BGR; flip to RGB and materialize a contiguous copy so the
    # following cv2/torch calls never see a negative-stride view.
    img = np.ascontiguousarray(img[:, :, ::-1])
    h, w = img.shape[:2]

    # equals to transforms.Resize(int): resize short side, keep aspect ratio
    if w >= h:
        new_w, new_h = expand_size * (w / h), expand_size
    else:
        new_w, new_h = expand_size, expand_size * (h / w)
    img = cv2.resize(img, (int(new_w), int(new_h)))  # cv2 takes (width, height)

    # equals to transforms.CenterCrop(int): square crop around the center;
    # start + size keeps the crop exactly input_size even when it is odd.
    h, w = img.shape[:2]
    top = h // 2 - input_size // 2
    left = w // 2 - input_size // 2
    img = img[top:top + input_size, left:left + input_size]

    # normalize with per-channel mean/std given in the 0-255 pixel range
    mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).view(1, 3, 1, 1)
    std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).view(1, 3, 1, 1)
    img_batch = torch.from_numpy(img).float().unsqueeze(0)  # (1, H, W, C) float32
    img_batch = img_batch.permute(0, 3, 1, 2)  # -> (1, C, H, W)
    return img_batch.sub_(mean).div_(std)
    
def load_torch_model(backbone_path):
    """Build the 'mobilenetv2' network and load matching weights into it.

    Only checkpoint entries whose names exist in the network are copied; any
    other parameter keeps its initial value (partial loading is deliberate).

    Args:
        backbone_path: path of the checkpoint file passed to ``torch.load``.

    Returns:
        The network, switched to eval mode.
    """
    import warnings

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the values into the (CPU) network anyway.
    checkpoint = torch.load(backbone_path, map_location='cpu')
    # Some checkpoints wrap the weights as {'state_dict': ..., 'epoch': ...};
    # unwrap only when that key is present and holds a dict.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('state_dict'), dict):
        checkpoint = checkpoint['state_dict']

    # NOTE(review): `models` is a project-local module (not visible here).
    net = models.__dict__['mobilenetv2'](width_mult=1.0)
    model_dict = net.state_dict()

    prefix = 'module.'
    pretrained_dict = {}
    for k, v in checkpoint.items():
        # Weights saved from nn.DataParallel carry a 'module.' prefix; strip it
        # only when that turns an unknown key into a known one.
        if k not in model_dict and k.startswith(prefix) and k[len(prefix):] in model_dict:
            k = k[len(prefix):]
        if k in model_dict:
            pretrained_dict[k] = v
    if not pretrained_dict:
        # The old filter silently produced an untrained model in this case.
        warnings.warn(
            f"no parameter in {backbone_path!r} matched the model; "
            "it keeps its initial (untrained) weights"
        )

    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    net.eval()  # important: disables dropout/batch-norm updates so outputs are repeatable
    return net

def torch_to_onnx(torch_model, onnx_path='model.onnx', input_shape=(3, 224, 224),
                  batch_size=1, opset_version=12):
    """Export a PyTorch model to an ONNX file.

    Args:
        torch_model: the model to export (should already be in eval mode).
        onnx_path: destination file (default 'model.onnx').
        input_shape: per-sample input shape, without the batch dimension.
        batch_size: batch dimension of the dummy input used for export.
        opset_version: ONNX opset to target (default 12).

    Returns:
        ``onnx_path``, the file that was written.
    """
    # Export records the ops run on this dummy input; its values are
    # irrelevant unless the model has data-dependent control flow.
    x = torch.ones(batch_size, *input_shape)
    torch.onnx.export(
        torch_model,
        x,
        onnx_path,
        opset_version=opset_version,
        # These names are the keys used by onnxruntime's session.run feed/outputs.
        input_names=['input'],
        output_names=['output'],
    )
    return onnx_path

# Compare the PyTorch model against its ONNX export on the same input.
def compare_torch_onnx(torch_model, onnx_sess, img_batch):
    """Check that the PyTorch model and its ONNX export agree on one batch.

    Both outputs are flattened before softmax, so this presumes a batch of a
    single image (as produced by ``get_img_batch``).

    Args:
        torch_model: the PyTorch model (expected in eval mode).
        onnx_sess: onnxruntime session whose input is named ``'input'``.
        img_batch: CPU float tensor fed to both models.

    Returns:
        ``(torch_index, onnx_index)``: argmax class index of each model.

    Raises:
        AssertionError: if the softmax outputs differ at 6 decimal places.
    """
    # ONNX Runtime consumes numpy arrays; the first output is the feature.
    sess_out = onnx_sess.run(None, {'input': img_batch.numpy()})
    onnx_feat = torch.from_numpy(np.asarray(sess_out[0], dtype='float32').flatten())
    onnx_pred = torch.nn.functional.softmax(onnx_feat, dim=0).numpy()
    onnx_index = int(np.argmax(onnx_pred))  # output class index

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        torch_feat = torch_model(img_batch).flatten()
    torch_pred = torch.nn.functional.softmax(torch_feat, dim=0).numpy().astype('float32')
    torch_index = int(np.argmax(torch_pred))  # output class index

    # Check the difference between outputs before and after conversion.
    np.testing.assert_almost_equal(torch_pred, onnx_pred, decimal=6)
    return torch_index, onnx_index
    
if __name__ == '__main__':
    set_seed()
    backbone_pth = 'model.pth.tar'
    img_path = '1.jpg'
    onnx_path = 'model.onnx'  # file written by torch_to_onnx by default

    # 1. load the PyTorch weights and export them to ONNX
    torch_model = load_torch_model(backbone_pth)
    torch_to_onnx(torch_model)
    # 2. open the exported model; it must exist before the session is created
    onnx_model = onnxruntime.InferenceSession(
        onnx_path, None, providers=['CPUExecutionProvider'])
    # 3. evaluation: run both models on the same image and compare outputs
    img_batch = get_img_batch(img_path)
    compare_torch_onnx(torch_model, onnx_model, img_batch)
    print('PyTorch and ONNX outputs match')

ONNX 模型转 TensorFlow

import onnx
from onnx_tf.backend import prepare

filename = 'model.onnx'
target_file_path = './tfmodel'
# load the ONNX model and validate it first, so a malformed graph fails here
# with a clear checker error instead of deep inside the onnx-tf conversion
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model)
# wrap the ONNX graph in an onnx-tf TensorFlow backend representation
tf_rep = prepare(onnx_model)
# save the TensorFlow model to the path (read later via from_saved_model)
tf_rep.export_graph(target_file_path)

TensorFlow 模型转 TFLite

# 上一步 export_graph 保存的已经是 SavedModel(pb)格式,可直接用 from_saved_model 转换,无需先转为 pb;如果不是 pb 格式,参考:https://blog.csdn.net/qxqxqzzz/article/details/119668426?spm=1001.2014.3001.5501
def tf_tflite(tf_model_path='./tfmodel', tflite_model_path='model.tflite'):
    """Convert a TensorFlow SavedModel directory into a TFLite model file.

    Args:
        tf_model_path: SavedModel directory to convert (default './tfmodel').
        tflite_model_path: output file for the TFLite flatbuffer
            (default 'model.tflite').

    Returns:
        ``tflite_model_path``, the file that was written.
    """
    import tensorflow as tf  # `tf` was used here but never imported

    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    # Attribute assignment (not tuple unpacking): allow TFLite builtin ops and
    # fall back to TensorFlow ops for anything without a builtin equivalent.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)
    return tflite_model_path

def tflite_prediction(img_batch, tflite_model='model.tflite'):
    """Run one batch through a TFLite model, print and return the class index.

    Args:
        img_batch: input batch (e.g. the torch tensor from ``get_img_batch``,
            or anything ``np.asarray`` accepts); its layout must match the
            model's first input.
        tflite_model: path of the .tflite file (default 'model.tflite').

    Returns:
        Tensor of argmax class indices along axis 1 of the softmax output.
    """
    import numpy as np
    import tensorflow as tf  # `tf` was used here but never imported

    interpreter = tf.lite.Interpreter(model_path=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # set_tensor needs a numpy array of the input's declared dtype, so convert
    # torch/TF tensors and cast up front.
    input_data = np.asarray(img_batch, dtype=input_details[0]['dtype'])
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_pred = interpreter.get_tensor(output_details[0]['index'])  # output feature
    tflite_pred = tf.nn.softmax(tf.convert_to_tensor(tflite_pred))
    pred_index = tf.argmax(tflite_pred, 1)  # output class index
    print(pred_index)
    return pred_index

你可能感兴趣的:(#,DL-基础,#,DL-部署)