TensorFlow 对象检测 API 教程5

TensorFlow 对象检测 API 教程 - 第5部分:保存和部署模型

在本教程的这一步,认为已经选择了预先训练的对象检测模型,调整现有的数据集或创建自己的数据集,并将其转换为 TFRecord 文件,修改模型配置文件并开始训练。但是,现在需要保存模型并将其部署到项目中。

一. 将检查点模型 (.ckpt) 保存为 .pb 文件

回到 TensorFlow 对象检测文件夹,并将 export_inference_graph.py 文件复制到包含模型配置文件的文件夹中。


python export_inference_graph.py --input_type image_tensor --pipeline_config_path ./rfcn_resnet101_coco.config --trained_checkpoint_prefix ./models/train/model.ckpt-5000 --output_directory ./fine_tuned_model

这将创建一个新的目录 fine_tuned_model ,其中模型名为 frozen_inference_graph.pb

二.在项目中使用模型

在本指南中一直在研究的项目是创建一个交通灯分类器。在 Python 中,可以将这个分类器作为一个类来实现。在类的初始化部分中,可以创建一个 TensorFlow 会话,以便在每次需要分类时都不需要创建它。


class TrafficLightClassifier(object):
    def __init__(self):
        PATH_TO_MODEL = 'frozen_inference_graph.pb'
        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            # Works up to here.
            with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
        self.sess = tf.Session(graph=self.detection_graph)


在这个类中,创建了一个函数,在图像上运行分类,并返回图像中分类的边界框,分数和类。


def get_classification(self, img):
    # Bounding Box Detection.
    with self.detection_graph.as_default():
        # Expand dimension since the model expects image to have shape [1, None, None, 3].
        img_expanded = np.expand_dims(img, axis=0)  
        (boxes, scores, classes, num) = self.sess.run(
            [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
            feed_dict={self.image_tensor: img_expanded})
    return boxes, scores, classes, num

此时,需要过滤低于指定分数阈值的结果。结果自动从最高分到最低分,所以这相当容易。用上面的函数返回分类结果,做完以上这些就完成了!

下面可以看到交通灯分类器在行动

TensorFlow 对象检测 API 教程5_第1张图片

你可能感兴趣的:(TensorFlow 对象检测 API 教程5)