基于http://bbs.gpuworld.cn/index.php?topic=73098.msg84029来做
问题1:download_detection_model及build_detection_graph API找不到定义
不清楚TensorFlow对象检测API中是否包含这两个函数(用find搜索未找到)。解决方法:改用tf_trt_models.detection中的同名函数
问题2:在jetson nano上导入tf-trt优化后的graph时,出现以下错误“tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered ‘TRTEngineOp’ in binary running on jetson-nano.”
原因是没有导入trt模块(导入该模块后TRTEngineOp才会被注册),在推理脚本中加入"import tensorflow.contrib.tensorrt as trt"即可
问题3:基于mobilenet+SSD v1的单帧推理耗时在100ms~170ms之间
将jetson nano的用户图形界面关闭后,一帧画面的推理耗时可以降到50~60ms
/*关闭用户图形界面*/
sudo systemctl set-default multi-user.target
sudo reboot
/*开启用户图形界面*/
sudo systemctl set-default graphical.target
sudo reboot
/******************create_trt_graph.py***************************/
from tf_trt_models.detection import download_detection_model
from tf_trt_models.detection import build_detection_graph
import tensorflow.contrib.tensorrt as trt
# Fetch the pipeline config and checkpoint for the 'ssd_mobilenet_v1_coco'
# model through the tf_trt_models helper.
config_path, checkpoint_path = download_detection_model('ssd_mobilenet_v1_coco')

# Build the frozen detection graph for a single-image batch; the helper also
# returns the graph's input and output node names.
frozen_graph, input_names, output_names = build_detection_graph(
    config=config_path,
    checkpoint=checkpoint_path,
    score_threshold=0.3,
    batch_size=1
)

# Convert the frozen graph with TF-TRT.
#   max_workspace_size_bytes: 1 << 25 bytes = 32 MiB.
#   precision_mode:           'FP16'.
#   minimum_segment_size:     50.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,
    max_workspace_size_bytes=1 << 25,
    precision_mode='FP16',
    minimum_segment_size=50
)

# Serialize the optimized graph. The `with` statement closes the file, so no
# explicit f.close() is needed.
with open('./data/trt_graph.pb', 'wb') as f:
    f.write(trt_graph.SerializeToString())
import tensorflow as tf
import cv2
from IPython.display import Image as DisplayImage
import tensorflow.contrib.tensorrt as trt
import numpy as np
import os
import time
from PIL import Image
def get_frozen_graph(graph_file):
    """Read a frozen graph file from disk.

    Args:
        graph_file: Path of the serialized GraphDef (``.pb``) file to load.

    Returns:
        A ``tf.GraphDef`` parsed from the bytes of ``graph_file``.
    """
    # tf.gfile.FastGFile is deprecated in TensorFlow 1.x; tf.gfile.GFile is
    # the supported replacement with the same read interface.
    with tf.gfile.GFile(graph_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def
# The TensorRT inference graph file downloaded from Colab or your local machine.
pb_fname = "./data/trt_graph.pb"
trt_graph = get_frozen_graph(pb_fname)
input_names = ['image_tensor']

# Session configuration: enable the allow_growth GPU option.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

# Create session and load graph.
with tf.Session(config=tf_config) as tf_sess:
    # Bare attribute access kept from the original script.
    # NOTE(review): presumably this forces TensorFlow's lazily loaded contrib
    # package to load before the graph is imported - confirm it is still
    # needed given the explicit tensorrt import.
    tf.contrib.resampler
    tf.import_graph_def(trt_graph, name='')

    tf_input = tf_sess.graph.get_tensor_by_name(input_names[0] + ':0')
    tf_scores = tf_sess.graph.get_tensor_by_name('detection_scores:0')
    tf_boxes = tf_sess.graph.get_tensor_by_name('detection_boxes:0')
    tf_classes = tf_sess.graph.get_tensor_by_name('detection_classes:0')
    tf_num_detections = tf_sess.graph.get_tensor_by_name('num_detections:0')

    # Build the non-max-suppression op ONCE and feed it through placeholders.
    # Calling tf.image.non_max_suppression inside the frame loop would add
    # new nodes to the graph on every frame, so the graph would keep growing.
    nms_boxes = tf.placeholder(tf.float32, shape=[None, 4], name='nms_boxes')
    nms_scores = tf.placeholder(tf.float32, shape=[None], name='nms_scores')
    nms_pick = tf.image.non_max_suppression(
        boxes=nms_boxes,
        scores=nms_scores,
        iou_threshold=0.1,
        max_output_size=5
    )

    video_path = "./data/road.mp4"
    vid = cv2.VideoCapture(video_path)
    frames_read = 0
    try:
        while True:
            return_value, frame = vid.read()
            if not return_value:
                # No frame at all means the video could not be read: keep
                # the original error. Otherwise this is simply end of stream.
                if frames_read == 0:
                    raise ValueError("No Image!")
                break
            frames_read += 1

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = cv2.resize(frame, (300, 300))

            prev_time = time.time()
            scores, boxes, classes, num_detections = tf_sess.run(
                [tf_scores, tf_boxes, tf_classes, tf_num_detections],
                feed_dict={tf_input: image[None, ...]}
            )
            boxes = boxes[0]  # index by 0 to remove batch dimension
            scores = scores[0]
            classes = classes[0]
            num_detections = int(num_detections[0])
            curr_time = time.time()
            exec_time = curr_time - prev_time
            print("time:%.2f ms" % (1000 * exec_time))

            # Scale boxes to pixel (image) coordinates. Entries 0 and 2 are
            # scaled by the image height and entries 1 and 3 by its width.
            # Slicing keeps a (0, 4) array when nothing was detected, so the
            # zero-detection case no longer fails.
            height, width = image.shape[0], image.shape[1]
            scale = np.array([height, width, height, width])
            boxes_pixels = np.round(boxes[:num_detections] * scale).astype(int)

            # Remove overlapping boxes with non-max suppression, return
            # picked indexes.
            pick = tf_sess.run(
                nms_pick,
                feed_dict={
                    nms_boxes: boxes_pixels.astype(np.float32),
                    nms_scores: scores[:num_detections],
                }
            )
            print(pick)

            for i in pick:
                # Plain Python ints keep cv2 happy regardless of numpy dtype.
                ymin, xmin, ymax, xmax = (int(v) for v in boxes_pixels[i])
                # Draw bounding box.
                image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
                # Label text (class index and probability).
                # NOTE(review): currently unused - the draw_label helper the
                # original referred to is not defined in the visible file.
                label = "{}:{:.2f}".format(int(classes[i]), scores[i])

            # The frame was converted to RGB for the network; cv2.imwrite
            # expects BGR, so convert back to avoid swapped red/blue channels.
            cv2.imwrite("./data/result.jpg", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the capture handle, even if inference fails.
        vid.release()