如何在PC端测试转换后的TFLite模型

上篇讲述如何把tensorflow模型转换成tflite模型,用于部署到移动端。

这篇分享如何在PC端对tflite模型进行预测,测试模型是否可用

首先,加载tflite模型,查看模型的输入输出

import numpy as np
import tensorflow as tf
import cv2

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="newModel.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)
print(output_details)

打印结果:

[{'name': 'input_1', 'index': 115, 'shape': array([  1, 224, 224,   3]), 'dtype': , 'quantization': (0.0, 0)}]
[{'name': 'activation_1/truediv', 'index': 6, 'shape': array([    1, 12544,     2]), 'dtype': , 'quantization': (0.0, 0)}]

input details里可看到,需要输入的numpy数组 [1, 224, 224, 3],数据类型为float32。
index:115的意思,我理解是传入数据放置的位置

于是我们需要这里处理下input的数据,为了方便,我把数据预存成csv文件了,取出来能直接使用

input_data = np.loadtxt('C:/Users\WIN10/input.csv',delimiter=',')
input_data = input_data.reshape(1,224,224,3)
input_data = input_data.astype(np.float32)
index = input_details[0]['index']
interpreter.set_tensor(index, input_data)

这样就把input数据设置好了,并且把数据传入网络模型

开始做预测:

interpreter.invoke()

读出预测结果

output_data = interpreter.get_tensor(output_details[0]['index'])
print('output_data shape:',output_data.shape)

output_data就是预测结果的源数据,print出来的shape就是上面output_detail 里的shape[ 1, 12544, 2]

最后还需要对预测的源数据进行后处理,解析出我们要的结果,这里每个工程的都不一样,仅供参考。

output_data = output_data.reshape(224,112)
pr = output_data.reshape(112,112,2).argmax( axis=2 )
seg_img = np.zeros( ( 112 , 112 , 3  ) )
seg_img[:,:,0] += ((pr[:,: ] == 1 )*200).astype('uint8')
seg_img[:,:,1] += ((pr[:,: ] == 1 )*200).astype('uint8')
seg_img[:,:,2] += ((pr[:,: ] == 1 )*200).astype('uint8')
cv2.imshow('img',seg_img)
cv2.waitKey(0)

你可能感兴趣的:(AI)