深度学习之目标检测object_detection代码实现

基于tensorflow的object_detection框架和slim框架,实现一个目标检测系统:

一:数据及准备

深度学习之目标检测object_detection代码实现_第1张图片

1.数据标注,使用labelImg对数据集进行标注,生成对应的xml文件

2.使用create_pet_tf_record.py脚本生成tfrecord文件,训练集train和验证集val

3.labels_items.txt设定物体类别

4. https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 中下载ssd_mobilenet_v1_coco预模型下载,解压成model.ckpt.data-00000-of-00001和model.ckpt.index和model.ckpt.meta

5.随机选择测试图片test.jpg

6.models/research/object_detection/samples/configs/ssd_mobilenet_v1_pets.config下config文件,修改符合要求

 

二:训练

运行train.py,出现各种问题,可能与tensorflow版本有关,

python ./object_detection/train.py --train_dir=$train_dir --pipeline_config_path=$pipeline_config_path

  1. 修改object_detection\export.py代码中的第72行的参数layout_optimizer替换为optimize_tensor_layout,可参考https://github.com/tensorflow/models/pull/3106
  2. 修改object_detection\data_decoders\tf_example_decoder.py中的dct_method=dct_method,要删去,参考 https://github.com/tensoflow/tensorflow/issues/17208
  3. 在export_inference_graph.py运行中,报tf.float32!=tf.in32的数据类型匹配错误。需要变换logit_scale的类型,改为tf.constant([[logit_scale]],tf.float32),参考https://github.com/tensorflow/models/issues/2774
  4. https://github.com/tensorflow/models/tree/0375c800c767db2ef070cee1529d8a50f42d1042

验证:运行eval.py

python ./object_detection/eval.py --checkpoint_dir=$checkpoint_dir --eval_dir=$eval_dir --pipeline_config_path=$pipeline_config_path

导出模型:运行export_inference_graph.py

python ./object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path $pipeline_config_path --trained_checkpoint_prefix $train_dir/model.ckpt-$current  --output_directory $output_dir/exported_graphs

inference:

python ./inference.py --output_dir=$output_dir --dataset_dir=$dataset_dir

 

输出测试结果图片

你可能感兴趣的:(代码实现)