使用opencv python导入tensorflow训练的Object Detection模型并进行预测

        最近在使用tensorflow训练Object Detection的模型,训练好的模型想用于视频中物体的识别和跟踪,由于opencv用于视频和图片的处理非常方便,所以想用opencv直接导入tensorflow训练好的模型。查了一下opencv从3.3版本开始正式支持DNN,可以直接导入caffe、tensorflow等框架训练好的模型,进而完成识别、检测等任务。

        opencv加载tensorflow训练好的模型,采用readNetFromTensorflow(model,config),第一个参数对应训练好的模型文件frozen_inference_graph.pb的路径,第二个参数对应于一个生成的config文件,它其实是一个protobuf格式的文本的网络结构定义,下文会讲如何生成。加载完之后,使用blobFromImage函数,将图片转换成blob格式,网络接收输入数据后,通过forward()函数进行前向传播,即可得到网络输出的结果,检测视频其实只需要对视频中每一帧进行检测,即可得到对视频的检测结果。

        先讲一下如何生成protobuf格式的网络结构定义config文件,opencv提供了转换脚本,如下:

  • tf_text_graph_ssd.py
  • tf_text_graph_faster_rcnn.py
  • tf_text_graph_mask_rcnn.py

        首先根据你选取的网络模型,选择对应的脚本,我用的是ssd_mobilenet_v2的。这个脚本有三个参数,第一个是你训练好的frozen_inference_graph.pb的路径,第二个是训练时使用的pipeline_config文件的路径,第三个就是config文件的输出路径了,如下:

python tf_text_graph_ssd.py --input /path/to/model.pb --config /path/to/example.config --output /path/to/graph.pbtxt

有了graph.pbtxt这个文件,我们就可以用opencv的readNetFromTensorflow导入训练好的模型了,具体如下:

import cv2 as cv

cvNet = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')

img = cv.imread('example.jpg')
rows = img.shape[0]
cols = img.shape[1]
cvNet.setInput(cv.dnn.blobFromImage(img, size=(300, 300), swapRB=True, crop=False))
cvOut = cvNet.forward()

for detection in cvOut[0,0,:,:]:
    score = float(detection[2])
    if score > 0.3:
        left = detection[3] * cols
        top = detection[4] * rows
        right = detection[5] * cols
        bottom = detection[6] * rows
        cv.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (23, 230, 210), thickness=2)
        cv.putText(img, str(score), (int(right), int(bottom)), cv.FONT_HERSHEY_SIMPLEX, 1, (23, 230, 210), 2)

cv.imshow('img', img)
cv.waitKey()

效果如下:

使用opencv python导入tensorflow训练的Object Detection模型并进行预测_第1张图片

 

其实也可以用opencv读取图片或视频,直接用tensorflow进行检测,如下:

import numpy as np
import tensorflow as tf
import cv2 as cv

# Read the graph.
with tf.gfile.FastGFile('frozen_inference_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Restore session
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    # Read and preprocess an image.
    img = cv.imread('example.jpg')
    rows = img.shape[0]
    cols = img.shape[1]
    inp = cv.resize(img, (300, 300))
    inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

    # Run the model
    out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                    sess.graph.get_tensor_by_name('detection_scores:0'),
                    sess.graph.get_tensor_by_name('detection_boxes:0'),
                    sess.graph.get_tensor_by_name('detection_classes:0')],
                   feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})

    # Visualize detected bounding boxes.
    num_detections = int(out[0][0])
    for i in range(num_detections):
        classId = int(out[3][0][i])
        score = float(out[1][0][i])
        bbox = [float(v) for v in out[2][0][i]]
        if score > 0.3:
            x = bbox[1] * cols
            y = bbox[0] * rows
            right = bbox[3] * cols
            bottom = bbox[2] * rows
            cv.rectangle(img, (int(x), int(y)), (int(right), int(bottom)), (125, 255, 51), thickness=2)

cv.imshow('TensorFlow MobileNet-SSD', img)
cv.waitKey()

 

你可能感兴趣的:(Python,tensorflow,Object,Detection)