import tensorflow as tf

def freeze_graph(input_checkpoint, output_graph):
    output_node_names = "output"  # name(s) of the output node(s) to keep
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # get the default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph and its weights
        output_graph_def = tf.graph_util.convert_variables_to_constants(  # freeze the model: turn variables into constants
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas
        with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
            f.write(output_graph_def.SerializeToString())  # serialize and write to disk
        # print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the final graph

if __name__ == '__main__':
    modelpath = "./checkPointModel/model.ckpt"
    freeze_graph(modelpath, "frozen.pb")
    print("finish!")
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph("frozen.pb",
                                                      input_arrays=["train_x"],
                                                      output_arrays=["output"])
converter.post_training_quantize = True  # apply post-training quantization
tflite_model = converter.convert()
open("model.tflite", "wb").write(tflite_model)
When a fixed input shape has to be specified, pass it via input_shapes:
import tensorflow as tf

path = "./fullLayer/"
converter = tf.lite.TFLiteConverter.from_frozen_graph(path + "frozen.pb",
                                                      input_arrays=["images"],
                                                      output_arrays=["output"],
                                                      input_shapes={"images": [1, 540, 960, 1]})
converter.post_training_quantize = True  # apply post-training quantization
tflite_model = converter.convert()
open(path + "quantized_model.tflite", "wb").write(tflite_model)
print("finish!")
import tensorflow as tf
import cv2 as cv
import numpy as np

if __name__ == "__main__":
    test_pb_model = True
    test_tflite_model = False
    read_change_graph = False
    pb_model_path = "./fullLayer/frozen.pb"
    tflite_model_path = "./fullLayer/quantized_model.tflite"  # layer2 fullLayer
    input_node_name = "images"
    output_node_name = "output"

    # Preprocess: resize, take the Y channel of YCrCb, scale to [-1, 1]
    src_img = cv.imread("1.jpg")
    cv.imwrite("src.jpg", src_img)
    src_img = cv.resize(src_img, (960, 540))
    src_img = cv.cvtColor(src_img, cv.COLOR_BGR2YCrCb)[:, :, 0]
    src_img = src_img / 127.5 - 1
    src_img = src_img.astype("float32")
    src_img = src_img.reshape((1, 540, 960, 1))

    if test_tflite_model:
        interpreter = tf.lite.Interpreter(tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        # print(str(input_details))
        output_details = interpreter.get_output_details()
        interpreter.set_tensor(input_details[0]["index"], src_img)
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]["index"])
        result = output_data[0]
        # Map back from [-1, 1] to [0, 255] and clip
        result = (result + 1) * 127.5
        result[result > 255] = 255
        result[result < 0] = 0
        result = result.astype(np.uint8)
        cv.imshow("result", result)
        cv.imwrite("result.jpg", result)
        cv.waitKey()

    if test_pb_model:
        src_img = cv.imread("1.jpg")
        src_img = cv.resize(src_img, (960, 540))
        src_img = cv.cvtColor(src_img, cv.COLOR_BGR2YCrCb)[:, :, 0]
        src_img = src_img / 127.5 - 1
        src_img = src_img.astype("float32")
        src_img = src_img.reshape((1, 540, 960, 1))
        input_image = tf.placeholder(tf.float32, (1, 540, 960, 1))
        with open(pb_model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            out_result = tf.import_graph_def(graph_def,
                                             input_map={"images:0": input_image},
                                             return_elements=["output:0"])
        sess = tf.Session()
        result = sess.run(out_result, feed_dict={input_image: src_img})
        result = result[0][0]
        result = (result + 1) * 127.5
        result[result > 255] = 255
        result[result < 0] = 0
        result = result.astype(np.uint8)
        cv.imshow("result", result)
        cv.waitKey()

    if read_change_graph:
        # Print every node name and op type in the frozen graph
        gf = tf.GraphDef()
        gf.ParseFromString(open(pb_model_path, "rb").read())
        for n in gf.node:
            print(n.name + " ===> " + n.op)
import tensorflow as tf
import cv2 as cv
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./mnist/", one_hot=True)
index = 100
one_mnist = mnist.train.images[index]
one_mnist_img = one_mnist.reshape((1, 28, 28, 1))
print("image real value:{}".format(np.argmax(mnist.train.labels[index], 0)))

test_image_dir = 'test.png'
# model_path = "./model/quantize_frozen_graph.tflite"
model_path = "model.tflite"

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensor details.
input_details = interpreter.get_input_details()
# print(str(input_details))
output_details = interpreter.get_output_details()
# print(str(output_details))

src = cv.imread(test_image_dir, cv.IMREAD_GRAYSCALE)
src = cv.resize(src, (28, 28))
# cv.imshow("gray_img", one_mnist_img.reshape([28, 28]))
# cv.waitKey()
# Add batch and channel dimensions: shape becomes [1, 28, 28, 1]
src = np.expand_dims(src, axis=0)
src = np.expand_dims(src, axis=3)
# print(src.shape)
src = src.astype('float32')  # the dtype must match the model's input
# src = one_mnist_img  # use an MNIST sample instead of test.png

# Feed the data and run inference
interpreter.set_tensor(input_details[0]['index'], src)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
# Drop the redundant batch dimension
result = np.squeeze(output_data)
# print('result:{}'.format(result))
# The output is a length-10 vector (digits 0-9); the index of the maximum value is the predicted digit
print('result:{}'.format((np.where(result == np.max(result)))[0][0]))
1. With TensorFlow 1.12.0, converting a pb model to tflite fails with 'No module named '_tensorflow_wrap_toco''; it turns out to be an official bug. Upgrading to tf-nightly 1.13 solved the problem. Uninstall the existing TensorFlow version before installing tf-nightly 1.13. After the install you may find that numpy.core... can no longer be imported; uninstall numpy, delete its leftover files, and then reinstall numpy.
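As a quick sanity check after the reinstall (a minimal sketch; the exact nightly version string will differ on your machine), confirm that both packages import cleanly and report the expected versions:

import numpy as np
import tensorflow as tf

# Sanity check after reinstalling: both imports should succeed, and
# tf.__version__ should report a 1.13 nightly build (exact string varies).
print("tensorflow:", tf.__version__)
print("numpy:", np.__version__)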