# TensorFlow 2: convert a SavedModel to TFLite

import tensorflow as tf
from utils.eval_utils import show_box,show_multibox
from utils.anchor_utils import generate_anchors, from_offset_to_box
import os
import cv2
import config as c
import numpy as np
import copy
from utils.aug_utils import color_normalize
# Hide all CUDA devices so conversion and inference below run on the CPU.
# NOTE(review): this executes after `import tensorflow`; it relies on TF not
# having initialized its GPU devices yet at this point — confirm if that changes.
os.environ.update(CUDA_VISIBLE_DEVICES='-1')

def pb_to_tflite(saved_model_path, tflite_path):
    """Convert a TensorFlow 2 SavedModel directory into a TFLite flatbuffer file.

    Both TFLite builtin ops and select TensorFlow ops are enabled, so graphs
    that use ops without a TFLite builtin kernel can still be converted.

    Args:
        saved_model_path: Directory containing the exported SavedModel.
        tflite_path: Destination path for the serialized ``.tflite`` model.
            Missing parent directories are created; an existing file is
            overwritten.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=saved_model_path)
    # The TargetSpec attribute is ``supported_ops``. The former spelling
    # ``support_ops`` only created an unused attribute, so SELECT_TF_OPS was
    # silently ignored and graphs needing TF ops could not be converted.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()

    # open() cannot create intermediate directories, so ensure they exist.
    out_dir = os.path.dirname(tflite_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(tflite_path, 'wb') as g:
        g.write(tflite_model)

def load_tflite(tflite_path, img_path):
    """Run a converted TFLite detector on one image and display its detections.

    The image is resized to the model input size taken from ``c.input_shape``,
    normalized with ``color_normalize``, and fed to a ``tf.lite.Interpreter``.
    The decoded boxes are scaled from the resized input's coordinates back to
    the original image size, printed, and drawn with ``show_multibox``.

    Args:
        tflite_path: Path to the ``.tflite`` model file.
        img_path: Path to an image file readable by ``cv2.imread``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising; without this check it would surface as an obscure unpack error.
    if img is None:
        raise FileNotFoundError('could not read image: {}'.format(img_path))
    height, width = img.shape[:2]

    # c.input_shape is indexed as (height, width, ...) when rescaling boxes
    # below, whereas cv2.resize takes dsize as (width, height). Passing
    # input_shape[:2] directly would transpose non-square model inputs.
    input_h, input_w = c.input_shape[0], c.input_shape[1]
    resized = cv2.resize(img, (int(input_w), int(input_h)))
    img_batch = np.array([color_normalize(resized)], dtype=np.float32)

    # load model
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # inference
    interpreter.set_tensor(input_details[0]['index'], img_batch)
    interpreter.invoke()

    # NOTE(review): assumes output 0 holds box offsets and output 1 holds class
    # scores. TFLite does not guarantee the outputs keep the SavedModel's
    # order — confirm against output_details (names/shapes) for this model.
    loc_pred = interpreter.get_tensor(output_details[0]['index'])
    print(loc_pred.shape)  # originally observed as (1, 8732, 4)
    cls_pred = interpreter.get_tensor(output_details[1]['index'])
    print(cls_pred.shape)  # originally observed as (1, 8732, 5)

    anchors = generate_anchors()
    boxes, scores, labels = from_offset_to_box(loc_pred[0], cls_pred[0], anchors,
                                               anchor_belongs_to_one_class=True, score_threshold=0.1)
    print(boxes, scores, labels)

    boxes_new = []
    score_new = []
    label_new = []
    for box, score, label in zip(boxes, scores, labels):
        # Map each coordinate from the resized model input back to the
        # original image resolution.
        box[0] = box[0] / input_w * width  # left
        box[1] = box[1] / input_h * height  # top
        box[2] = box[2] / input_w * width  # right
        box[3] = box[3] / input_h * height  # bottom
        print('image: {}\nclass: {}\nconfidence: {:.4f}\n'.format(img_path, c.class_list[label], score))
        boxes_new.append(box)
        score_new.append(score)
        label_new.append(c.class_list[label])
    show_multibox(img, boxes_new, score_new, label_new)


# Locations used when this file runs as a script.
# NOTE(review): machine-specific absolute paths; adjust before running elsewhere.
tflite_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/tflite/ssd.tflite'
saved_model_path = '/result/weight/ssd_vgg/pb/'
img_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/test_pic/image_1636168791399.jpg'

# Step 1: write the converted .tflite model to disk.
pb_to_tflite(saved_model_path=saved_model_path, tflite_path=tflite_path)
# Step 2: reload that file and run it on a single test image.
load_tflite(tflite_path=tflite_path, img_path=img_path)

# 你可能感兴趣的: Tensorflow, keras, python, tensorflow, 开发语言