A Step-by-Step Guide to the Tensorflow Object Detection API on Win10 (Training SSD+MobileNetv2 on VOC2007)

Project repository: https://github.com/tensorflow/models

This post walks through basic usage of the Tensorflow Object Detection API, using SSD+MobileNetv2 for both inference and training as the example, and keeping the description as plain as possible.

Step 1: Environment setup

The required environment is described in installation.md, but it does not cover Win10; on Win10 I still use Anaconda to set everything up.

# Create a virtual environment
conda create -n ssd python=3.6

# Activate the environment
conda activate ssd

# Install the dependencies
conda install tensorflow-gpu==1.15.0
conda install Cython
conda install contextlib2
conda install pillow
conda install lxml
conda install jupyter
conda install matplotlib

After installing the packages above, there is one final step: installing pycocotools. Without it the training script will not start.

For details on installing pycocotools on Win10, refer to the separate post “win10安装pycocotools” (installing pycocotools on Win10).
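If that post is unavailable, one commonly used option on Win10 is to install pycocotools with pip from a Windows-compatible fork (this assumes Git and the Visual C++ Build Tools are already installed; treat the command as an example rather than the only way):

pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI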

Step 2: Generate code with protoc

After downloading the repository, cd into a path like D:\...\tensorflow\models\research (adjust to your own location).

You will also need the protoc tool.

protoc download (Baidu Netdisk): https://pan.baidu.com/s/1FJsrFVYBtG-cT6mnuKznOw  extraction code: wrtb (it can also be downloaded from the protobuf releases page on GitHub)

Locate your protoc executable and, still in the research directory, run:

D:/.../bin/protoc object_detection/protos/*.proto --python_out=.
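If the command succeeds silently, every .proto file under object_detection/protos should now have a matching *_pb2.py file. An optional quick check from the research directory (just a sanity check, not required):

python -c "import glob; print(len(glob.glob('object_detection/protos/*_pb2.py')))"

A count of 0 means protoc did not generate anything and the command above needs to be re-run.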

Step 3: Set the Python paths

In your Anaconda installation, go to the envs/ssd/Lib/site-packages/ folder (adjust to your own location).

Create a new txt file there and rename it to tensorflow_model.pth.

Its content is the paths of the relevant folders in the models repository, one per line (adjust to your own paths):

D:\python\tensorflow\models\research
D:\python\tensorflow\models\research\slim
D:\python\tensorflow\models\research\object_detection
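To confirm the .pth file is being picked up, you can try importing the package from any directory in the activated ssd environment (an optional sanity check; Step 4 below does a more thorough test):

python -c "from object_detection.utils import label_map_util; print('object_detection found')"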

Step 4: Test whether the environment is configured correctly

 python object_detection/builders/model_builder_test.py

[Screenshot: output of model_builder_test.py]

Step 5: Detect a single image

import os
import sys
import tarfile

import cv2 as cv
import numpy as np
import tensorflow as tf
from utils import label_map_util
from utils import visualization_utils as vis_util


MODEL_NAME = 'ssd_mobilenet_v2_coco_2018_03_29'   # pre-trained model archive; no need to unzip it manually
MODEL_FILE = 'D:/python/tensorflow/' + MODEL_NAME + '.tar.gz'

PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

PATH_TO_LABELS = os.path.join('D:/python/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90    # the COCO label map has 90 classes
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
    file_name = os.path.basename(file.name)
    if 'frozen_inference_graph.pb' in file_name:
        tar_file.extract(file, os.getcwd())

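# Load the frozen inference graph (the .pb file extracted above) into a new TensorFlow graph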
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

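# Load the COCO label map and build the category index used to turn class ids into readable names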
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categorys = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
categorys_index = label_map_util.create_category_index(categorys)

def load_image_into_numpy(image):
    # (Not used below; it converts a PIL image to a numpy array. This script reads images with OpenCV instead.)
    (im_w, im_h) = image.size
    return np.array(image.getdata()).reshape(im_h, im_w, 3).astype(np.uint8)

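# Read a test image with OpenCV, run the graph in a session, and draw the detections on the image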
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image = cv.imread("D:/python/tensorflow/models/research/object_detection/test_images/image3.jpg")
        print(image.shape)
        image_np_expanded = np.expand_dims(image, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        (boxes, scores, classes, num_detections) = sess.run([boxes, scores, classes, num_detections],
                                                            feed_dict={image_tensor: image_np_expanded})
        vis_util.visualize_boxes_and_labels_on_image_array(
            image,
            np.squeeze(boxes),
            np.squeeze(classes.astype(np.int32)),
            np.squeeze(scores),
            categorys_index,
            use_normalized_coordinates=True,
            line_thickness=4
        )
        cv.namedWindow("enhanced", 0)
        cv.resizeWindow("enhanced", 640, 750)
        cv.imshow("enhanced", image)
        cv.waitKey(0)
        cv.destroyAllWindows()

That is the full code for detecting a single image. It uses OpenCV to read the image; if you don't have it, install it with pip (note: do not use conda install for opencv here).
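One small detail: cv.imread returns images in BGR channel order, while the detection model expects RGB (the official demo reads images with PIL). The demo above usually still produces reasonable boxes, but if you want to feed RGB you can convert before building the batch, for example (an optional tweak, not part of the script above):

image_rgb = cv.cvtColor(image, cv.COLOR_BGR2RGB)       # BGR -> RGB for the model
image_np_expanded = np.expand_dims(image_rgb, axis=0)  # add the batch dimension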

Step 6: Generate the VOC2007 TF-Record

Put the VOC2007 dataset in a convenient location. Here I use the training split as the example for generating a TF-Record; the val split is generated the same way by changing --set=train to --set=val. (The ^ at the end of each line is the cmd line-continuation character, so the command can be copied as one block.)

python object_detection/dataset_tools/create_pascal_tf_record.py ^
--label_map_path=D:/tensorflow/dataset/pascal_label_map.pbtxt ^
--data_dir=D:/tensorflow/dataset/VOCdevkit ^
--year=VOC2007 ^
--set=train ^
--output_path=D:/tensorflow/dataset/tfrecord/pascal_train.record
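To double-check that the record file was actually written, you can count the examples in it with the TF 1.x record iterator (a small verification sketch; the path is the one used above, adjust to your own):

import tensorflow as tf

record_path = "D:/tensorflow/dataset/tfrecord/pascal_train.record"
num_examples = sum(1 for _ in tf.python_io.tf_record_iterator(record_path))
print("examples in record:", num_examples)   # should be 2501 for the VOC2007 train split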

Step 7: Edit the config file

Config file path: \...\research\object_detection\samples\configs\ssd_mobilenet_v2_coco.config

Change 1: the number of classes, num_classes (from COCO's 90 to VOC's 20).

As the file itself suggests, press Ctrl+F, search for "PATH_TO_BE_CONFIGURED", and replace every occurrence of this placeholder with the corresponding path.

Change 2: the pre-trained weights path fine_tune_checkpoint (extract the downloaded archive and point this at its model.ckpt).

Change 3: edit train_input_reader and eval_input_reader to point at your dataset paths, for example:

train_input_reader: {
  tf_record_input_reader {
    input_path: "D:/python/tensorflow/dataset/output/pascal_train.record"
  }
  label_map_path: "D:/python/tensorflow/dataset/output/pascal_label_map.pbtxt"
}
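
For reference, after Changes 1-3 the edited parts of ssd_mobilenet_v2_coco.config look roughly like this (the paths are examples for this setup, and pascal_val.record assumes you also generated a val split in Step 6):

    num_classes: 20    # VOC has 20 classes

  fine_tune_checkpoint: "D:/python/tensorflow/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"

eval_input_reader: {
  tf_record_input_reader {
    input_path: "D:/python/tensorflow/dataset/output/pascal_val.record"
  }
  label_map_path: "D:/python/tensorflow/dataset/output/pascal_label_map.pbtxt"
}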

Step 8: Start training

python object_detection/model_main.py ^
--pipeline_config_path=D:/python/tensorflow/dataset/output/ssd_mobilenet_v2_coco.config ^
--model_dir=D:/python/tensorflow/dataset/output/model ^
--num_train_steps=10 ^
--num_eval_steps=5 ^
--alsologtostderr

(num_train_steps and num_eval_steps are set very low here just to verify that the whole pipeline runs; raise them substantially for a real training run.)
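
While training runs, you can monitor the loss and evaluation metrics with TensorBoard by pointing it at the model directory used above and then opening http://localhost:6006 in a browser:

tensorboard --logdir=D:/python/tensorflow/dataset/output/model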
