TensorFlow 2.x：将 saved_model.pb 转换为 frozen_model.pb（冻结图）

以从官网下载的预训练模型 faster_rcnn 为例，
下载后 faster_rcnn_resnet101_v1_640x640_1 目录结构如下：
（图 1：faster_rcnn_resnet101_v1_640x640_1 目录结构截图）
可以直接用 `network = tf.keras.models.load_model(pb_file_path)` 加载模型，再将其转换为冻结图。
完整代码如下：

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Convert a TF2 SavedModel directory into a single frozen GraphDef file
# (variables folded into constants) that can be loaded with tf.compat.v1.

# Path to the SavedModel directory (the folder that contains saved_model.pb).
pb_file_path = './faster_rcnn_resnet101_v1_640x640_1'
# Path to a sample image; it is only used to fix the input shape/dtype of the traced graph.
image_path = './xx.jpg'

# Build an example input; its shape and dtype become the frozen graph's input signature.
img = tf.io.read_file(image_path)
img = tf.image.decode_jpeg(img, channels=3)  # RGB, shape=(H, W, 3), e.g. (450, 600, 3)
img = tf.image.resize(img, (299, 299))       # float32, shape=(299, 299, 3)
img = tf.expand_dims(img, axis=0)            # shape=(1, 299, 299, 3)
# The traced input placeholder takes this dtype, so consumers must feed uint8 pixels in [0, 255].
img = tf.cast(img, dtype=tf.uint8)

# Load the model.
network = tf.keras.models.load_model(pb_file_path)

# Wrap the model in a tf.function and trace a ConcreteFunction for the fixed input spec.
# The input placeholder is named after the lambda argument ("x"), which is why the
# frozen graph is fed through the key 'x:0'.
full_model = tf.function(lambda x: network(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(img.shape, img.dtype))  # (1, 299, 299, 3), tf.uint8

# Replace variables with constants to get a frozen ConcreteFunction.
frozen_func = convert_variables_to_constants_v2(full_model)
graph_def = frozen_func.graph.as_graph_def()

# Print op names plus input/output tensors; the output tensor names printed here
# (e.g. Identity_N:0) are what the loader script must fetch.
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save the frozen GraphDef to ./frozen_models/frozen_graph.pb as binary protobuf.
tf.io.write_graph(graph_or_graph_def=graph_def,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

用 TensorFlow（tf.compat.v1）加载 frozen_model.pb

import cv2
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Run inference on the frozen graph produced by the conversion script.

# The conversion script writes the graph with tf.io.write_graph(logdir="./frozen_models",
# name="frozen_graph.pb"); adjust if the file was moved.
path = './frozen_models/frozen_graph.pb'
# Path to the input image.
image_path = './xx.jpg'
img_cv2 = cv2.imread(image_path)  # BGR, shape=(H, W, 3), e.g. (450, 600, 3)
if img_cv2 is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of raising.
    raise FileNotFoundError(image_path)

# Preprocess to match the conversion script's traced input: RGB, 299x299, raw pixel values.
# blobFromImage returns float32 in NCHW layout, shape=(1, 3, 299, 299).
blob = cv2.dnn.blobFromImage(img_cv2,
                             # Keep the [0, 255] range: the placeholder is uint8, and
                             # scaling by 1/255 would collapse every pixel to 0 or 1.
                             scalefactor=1.0,
                             size=(299, 299),
                             mean=(0, 0, 0),
                             # OpenCV loads BGR; the graph was traced on RGB (decode_jpeg).
                             swapRB=True,
                             crop=False)
# NCHW -> NHWC, shape=(1, 299, 299, 3); cast to uint8 to match the input placeholder dtype.
blob = np.transpose(blob, (0, 2, 3, 1)).astype(np.uint8)

with tf.gfile.GFile(path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into a dedicated graph and bind the session to it explicitly.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # Output tensor mapping: detection_scores -> Identity_6, detection_boxes -> Identity_5,
    # detection_classes -> Identity_2. Presumably taken from the "Frozen model outputs"
    # printout of the conversion script; verify for other models.
    scores, boxes, classes = sess.run(
        [graph.get_tensor_by_name('Identity_6:0'),
         graph.get_tensor_by_name('Identity_5:0'),
         graph.get_tensor_by_name('Identity_2:0')],
        feed_dict={'x:0': blob})
    out = {'detection_scores': scores, 'detection_boxes': boxes, 'detection_classes': classes}

你可能感兴趣的:(其他)