# Package versions this pipeline was used with:
#   tensorflow 2.3.1
#   pytorch 1.6.0
#   onnxruntime 1.8.1
#   cv2 4.5.3
#   onnx_tf 1.8.0
#   onnx 1.10.1
import os
import random

import cv2
import numpy as np
import onnxruntime
import torch
import torch.onnx
def set_seed(seed=1):
    """Seed every random generator so repeated PyTorch runs give identical outputs.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator (default 1).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe on CPU-only machines: PyTorch ignores CUDA seeding when CUDA is
    # unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Force deterministic cuDNN kernels and disable the autotuner, which can
    # pick different (non-deterministic) algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_img_batch(img_path, input_size=224):
    """Load an image and preprocess it into a normalized NCHW float batch.

    The preprocessing must match the model's training pipeline exactly. This
    version reproduces ``Resize(int(input_size / 0.875))`` (short side,
    aspect ratio kept) -> ``CenterCrop(input_size)`` -> ``Normalize`` with
    the ImageNet mean/std.

    Args:
        img_path: path of the image file to read with ``cv2.imread``.
        input_size: side length of the square center crop (default 224).

    Returns:
        A float32 torch tensor of shape ``(1, 3, input_size, input_size)``
        in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    expand_size = int(input_size / 0.875)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"cannot read image: {img_path}")
    img = img[:, :, ::-1]  # OpenCV loads BGR; reorder channels to RGB

    # Equivalent to transforms.Resize(int): scale the short side to
    # expand_size while keeping the aspect ratio.
    h, w = img.shape[:2]
    if w >= h:
        new_w, new_h = expand_size * (w / h), expand_size
    else:
        new_w, new_h = expand_size, expand_size * (h / w)
    img = cv2.resize(img, (int(new_w), int(new_h)))  # cv2 dsize is (width, height)

    # Equivalent to transforms.CenterCrop(int): square crop of exactly
    # input_size pixels per side (also correct for odd sizes).
    h, w = img.shape[:2]
    top = h // 2 - input_size // 2
    left = w // 2 - input_size // 2
    img = img[top:top + input_size, left:left + input_size]

    # Normalize with the ImageNet mean/std, scaled to the 0-255 pixel range.
    mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).view(1, 3, 1, 1)
    std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).view(1, 3, 1, 1)
    img_batch = torch.from_numpy(img).float().unsqueeze(0)  # (1, H, W, C) float32
    img_batch = img_batch.permute(0, 3, 1, 2)  # -> (1, C, H, W)
    return img_batch.sub_(mean).div_(std)
def load_torch_model(backbone_path):
    """Build a MobileNetV2 and load the matching weights from a checkpoint.

    Only checkpoint entries whose names exist in the freshly built network
    are copied; all other entries are ignored. The model is returned in eval
    mode, which is required for PyTorch to give the same output on every run.

    NOTE(review): ``models`` is not imported in this file; it is assumed to
    be a project module exposing a ``mobilenetv2(width_mult=...)``
    constructor -- confirm.

    Args:
        backbone_path: path of the checkpoint file passed to ``torch.load``.

    Returns:
        The ``torch.nn.Module`` with weights loaded, in eval mode.

    Raises:
        RuntimeError: if no checkpoint entry matches a model parameter name,
            which would otherwise silently leave the network randomly
            initialized.
    """
    # Load onto the CPU so GPU-saved checkpoints also work on CPU-only hosts;
    # the tensors are copied into the CPU model below either way.
    checkpoint = torch.load(backbone_path, map_location='cpu')
    # Training scripts often save {'state_dict': ..., 'epoch': ..., ...};
    # unwrap so the weight names can be matched against the model.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('state_dict'), dict):
        checkpoint = checkpoint['state_dict']

    net = models.__dict__['mobilenetv2'](width_mult=1.0)
    model_dict = net.state_dict()

    pretrained_dict = {}
    for key, value in checkpoint.items():
        # Weights saved from nn.DataParallel carry a 'module.' prefix.
        if key not in model_dict and key.startswith('module.'):
            key = key[len('module.'):]
        if key in model_dict:
            pretrained_dict[key] = value
    if not pretrained_dict:
        raise RuntimeError(
            f"no parameters in {backbone_path!r} match the model's state_dict")

    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    net.eval()  # important: disables dropout / BN updates so outputs are repeatable
    return net
def torch_to_onnx(torch_model, onnx_path='model.onnx', input_shape=(3, 224, 224),
                  opset_version=12):
    """Export a PyTorch model to an ONNX file using a dummy input batch.

    The exported graph has a fixed batch size of 1 (no ``dynamic_axes``),
    one input named ``'input'`` and one output named ``'output'``.

    Args:
        torch_model: the ``torch.nn.Module`` to export; it should already be
            in eval mode.
        onnx_path: destination file (default ``'model.onnx'``).
        input_shape: per-sample input shape used for the dummy input
            (default ``(3, 224, 224)``).
        opset_version: ONNX opset to target (default 12).

    Returns:
        The path the ONNX model was written to.
    """
    batch_size = 1
    # The dummy input only fixes the exported input shape; its values are
    # not stored in the model.
    dummy_input = torch.ones(batch_size, *input_shape)
    torch.onnx.export(
        torch_model,
        dummy_input,
        onnx_path,
        opset_version=opset_version,
        input_names=['input'],
        output_names=['output'],
    )
    return onnx_path
def compare_torch_onnx(torch_model, onnx_sess, img_batch, input_name='input'):
    """Check that the PyTorch model and its ONNX export agree on one batch.

    Both raw outputs are flattened and passed through softmax, so this
    assumes a batch of one. The resulting probabilities must match to 6
    decimal places.

    Args:
        torch_model: PyTorch model, called as ``torch_model(img_batch)``.
        onnx_sess: ONNX Runtime session (anything with a compatible
            ``run(output_names, input_feed)`` method).
        img_batch: input torch tensor, e.g. from ``get_img_batch``.
        input_name: name of the ONNX graph input (default ``'input'``).

    Returns:
        ``(torch_index, onnx_index)``: the predicted class index from each
        runtime.

    Raises:
        AssertionError: if the softmax outputs differ beyond 6 decimals.
    """
    # ONNX Runtime: keep the first output only, as flat float32 logits.
    onnx_out = onnx_sess.run(None, {input_name: img_batch.numpy()})[0]
    onnx_logits = torch.from_numpy(np.asarray(onnx_out, dtype='float32').flatten())
    onnx_pred = torch.nn.functional.softmax(onnx_logits, dim=0)

    # PyTorch: inference only, so skip building the autograd graph.
    with torch.no_grad():
        torch_logits = torch_model(img_batch).flatten()
    torch_pred = torch.nn.functional.softmax(torch_logits, dim=0)

    onnx_index = int(torch.argmax(onnx_pred))
    torch_index = int(torch.argmax(torch_pred))

    # The outputs must agree after conversion.
    np.testing.assert_almost_equal(
        torch_pred.numpy().astype('float32'), onnx_pred.numpy(), decimal=6)
    return torch_index, onnx_index
if __name__ == '__main__':
    set_seed()
    backbone_pth = 'model.pth.tar'
    onnx_path = 'model.onnx'
    img_path = '1.jpg'

    # PyTorch -> ONNX. Export before opening the session so ONNX Runtime
    # loads the model that was just produced from backbone_pth.
    torch_model = load_torch_model(backbone_pth)
    torch_to_onnx(torch_model)
    onnx_sess = onnxruntime.InferenceSession(onnx_path, None)

    # Evaluation: PyTorch and ONNX Runtime must agree on the same image.
    img_batch = get_img_batch(img_path)
    compare_torch_onnx(torch_model, onnx_sess, img_batch)

    # ONNX -> TensorFlow, exported to target_file_path.
    import onnx
    from onnx_tf.backend import prepare
    target_file_path = './tfmodel'
    onnx_proto = onnx.load(onnx_path)
    tf_rep = prepare(onnx_proto)
    tf_rep.export_graph(target_file_path)
    # NOTE: tf_tflite() and tflite_prediction() are defined after this block,
    # so they cannot be called from here; run them once './tfmodel' exists.
def tf_tflite(tf_model_path='./tfmodel', tflite_model_path='model.tflite'):
    """Convert the TensorFlow SavedModel in ``tf_model_path`` to a TFLite file.

    The model saved by the previous (ONNX -> TensorFlow) step is already in
    .pb / SavedModel form, so no separate conversion to .pb is needed. If
    your model is not in that format, see:
    https://blog.csdn.net/qxqxqzzz/article/details/119668426?spm=1001.2014.3001.5501

    Args:
        tf_model_path: SavedModel directory to read (default ``'./tfmodel'``).
        tflite_model_path: output file (default ``'model.tflite'``).

    Returns:
        The path the TFLite model was written to.
    """
    import tensorflow as tf  # imported lazily; only the TFLite steps need it

    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    # Let ops without a TFLite builtin kernel fall back to Select TF ops.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)
    return tflite_model_path
def tflite_prediction(img_batch, tflite_model_path='model.tflite'):
    """Run one batch through the TFLite model and print the predicted class.

    Args:
        img_batch: input batch, e.g. the torch tensor from ``get_img_batch``;
            objects with a ``.numpy()`` method (torch/TF tensors) and NumPy
            arrays are accepted.
        tflite_model_path: TFLite model file (default ``'model.tflite'``).

    Returns:
        NumPy array of softmax probabilities over the model's first output.
    """
    import tensorflow as tf  # imported lazily; only the TFLite steps need it

    if hasattr(img_batch, 'numpy'):
        img_batch = img_batch.numpy()
    # set_tensor needs a NumPy array; the permuted torch batch is not
    # contiguous, so make a contiguous float32 copy. Assumes a float
    # (non-quantized) model input -- TODO confirm for quantized exports.
    input_data = np.ascontiguousarray(img_batch, dtype=np.float32)

    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    logits = interpreter.get_tensor(output_details[0]['index'])  # output feature
    tflite_pred = tf.nn.softmax(tf.convert_to_tensor(logits))
    print(tf.argmax(tflite_pred, 1))  # output class index
    return tflite_pred.numpy()