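TensorFlow provides a good API for object detection, but using it involves environment setup, data preparation and several other steps, which makes it fairly cumbersome. This article works on Ubuntu with TensorFlow 1.7 and walks through every detail of calling the API. If you are new to this area, you can follow the steps here to train on your own data, get a trained network and results, and take your first step into object detection. (Thanks to Mei-ge for the friendly guidance. The article also uses programs provided by other experts; thanks to them as well.)
1. Download the TensorFlow Object Detection API from GitHub at https://github.com/tensorflow/models, either with git (git clone https://github.com/tensorflow/models.git) or with a client such as TortoiseGit. Note that TensorFlow currently offers only this single GitHub link for the download. Using a VPN or proxy is recommended, since the download is otherwise very slow. (Alternatively, copy the code from a local collection of the programs, if you have one.)
2. After the download finishes there is a models folder containing a research folder. All following steps treat research as the root directory, and all paths are relative to it.
3. Download protoc version 2.6 or later from https://github.com/google/protobuf/releases, choosing the build that matches your operating system and architecture.
4. Use protoc to compile the proto files, i.e. turn the .proto files under research/object_detection/protos into .py files. From the research directory, run this in a terminal: protoc object_detection/protos/*.proto --python_out=.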
When it finishes, every .proto file in research/object_detection/protos has a corresponding generated .py file (named <name>_pb2.py).
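A quick way to confirm the compilation step is to check that every .proto file now has a generated module next to it. The sketch below assumes protoc's usual naming, where foo.proto produces foo_pb2.py; the helper name missing_pb2_modules is made up for illustration.

```python
import glob
import os


def missing_pb2_modules(protos_dir):
    """Return the .proto files in protos_dir that have no matching *_pb2.py file."""
    missing = []
    for proto_path in sorted(glob.glob(os.path.join(protos_dir, '*.proto'))):
        stem = os.path.splitext(os.path.basename(proto_path))[0]
        if not os.path.exists(os.path.join(protos_dir, stem + '_pb2.py')):
            missing.append(os.path.basename(proto_path))
    return missing


if __name__ == '__main__':
    # Run from the research directory, after the protoc command above.
    missing = missing_pb2_modules('object_detection/protos')
    print('all protos compiled' if not missing else 'not compiled: %s' % missing)
```

An empty list means the compilation step covered every proto file.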
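5. Add the slim folder, together with the research folder itself (so that the object_detection package can be imported), to the PYTHONPATH environment variable. Open the file with vim ~/.bashrc (press i to enter insert mode) and add export PYTHONPATH=$PYTHONPATH:/home/pc/Deep-Learning-21-Examples-master/chapter_5/research:/home/pc/Deep-Learning-21-Examples-master/chapter_5/research/slim, replacing the absolute paths with your own (press Esc, then type :wq to save and quit). Then run source ~/.bashrc to apply the setting; echo $PYTHONPATH shows the current value.
After setting the environment variable, run python object_detection/builders/model_builder_test.py to test the installation.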
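If the test fails with an import error, it can help to confirm that PYTHONPATH really contains the expected directories. A minimal sketch, assuming you pass in the path of your research directory; missing_pythonpath_entries is a hypothetical helper, and it checks for both research/ and research/slim, as the Object Detection API installation instructions do.

```python
import os


def missing_pythonpath_entries(research_dir, pythonpath=None):
    """Return which of research/ and research/slim are absent from PYTHONPATH."""
    if pythonpath is None:
        pythonpath = os.environ.get('PYTHONPATH', '')
    # Normalize every entry so that trailing slashes or relative paths compare equal.
    entries = {os.path.abspath(p) for p in pythonpath.split(os.pathsep) if p}
    wanted = [os.path.abspath(research_dir),
              os.path.abspath(os.path.join(research_dir, 'slim'))]
    return [p for p in wanted if p not in entries]


if __name__ == '__main__':
    # Run from the research directory.
    print(missing_pythonpath_entries(os.getcwd()))
```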
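6. Train a new model. The first step is data preparation.
(1) Copy the new dataset, organized in Pascal VOC format, into the object_detection directory; the ROS1 folder is used as the example here.
The dataset folder is organized as follows: ROS1/VOCdevkit/RSDS2016/, which contains Annotations/, ImageSets/Main/ and JPEGImages/.
The dataset goes under the VOCdevkit folder. The RSDS2016 folder can be renamed, but its name must match the <folder> field in the .xml files, otherwise an error is raised.
The years setting in create_pascal_tf_record.py under /research/object_detection also has to be adjusted to this name.
RSDS2016/ contains three folders: JPEGImages/ holds the dataset images, Annotations/ holds the labels (.xml files), and ImageSets/ contains a Main folder with four files that record which images belong to the training and test sets. These are also the first files the program reads when it runs.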
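The layout described above can be checked before converting anything. This is a sketch that assumes voc_root points at the RSDS2016 folder; both helper names are hypothetical, and the pairing check assumes every image is a .jpg with the same base name as its .xml file.

```python
import os

# Sub-directories the VOC-style layout requires, relative to RSDS2016/.
REQUIRED_SUBDIRS = ('Annotations', 'JPEGImages', os.path.join('ImageSets', 'Main'))


def missing_voc_dirs(voc_root):
    """Return the required sub-directories that do not exist under voc_root."""
    return [d for d in REQUIRED_SUBDIRS
            if not os.path.isdir(os.path.join(voc_root, d))]


def annotations_without_images(voc_root):
    """Return base names of .xml files that have no matching .jpg image."""
    xml_stems = {os.path.splitext(f)[0]
                 for f in os.listdir(os.path.join(voc_root, 'Annotations'))
                 if f.endswith('.xml')}
    jpg_stems = {os.path.splitext(f)[0]
                 for f in os.listdir(os.path.join(voc_root, 'JPEGImages'))
                 if f.lower().endswith('.jpg')}
    return sorted(xml_stems - jpg_stems)
```

ImageSets/Main is created by the split script in the next step, so it may legitimately be reported as missing at this point.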
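(2) These four files are usually not shipped with a dataset and have to be generated. The file_text.py script below does this; set the absolute paths and the split ratios in it, and put the file in the object_detection directory.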
# -*- coding:utf-8 -*-
'''
Generate the ImageSets/Main split files (test/trainval/train/val) of a VOC2007-style dataset.
'''
import os
__author__ = 'chendingxin'
_IMAGE_SETS_PATH = '/home/pc/abc/research/object_detection/RED/VOCdevkit/RSDS2016/ImageSets'
_MAin_PATH = '/home/pc/abc/research/object_detection/RED/VOCdevkit/RSDS2016/ImageSets/Main'
_XML_FILE_PATH = '/home/pc/abc/research/object_detection/RED/VOCdevkit/RSDS2016/Annotations'
if __name__ == '__main__':
    if os.path.exists(_IMAGE_SETS_PATH):
        print('ImageSets dir already exists')
        if os.path.exists(_MAin_PATH):
            print('Main dir is already in ImageSets')
        else:
            os.mkdir(_MAin_PATH)
    else:
        os.mkdir(_IMAGE_SETS_PATH)
        os.mkdir(_MAin_PATH)
    print(_MAin_PATH)
    # test set
    f_test = open(os.path.join(_MAin_PATH, 'test.txt'), 'w')
    # training + validation set
    f_trainval = open(os.path.join(_MAin_PATH, 'trainval.txt'), 'w')
    # training part of trainval
    f_train = open(os.path.join(_MAin_PATH, 'train.txt'), 'w')
    # validation part of trainval
    f_val = open(os.path.join(_MAin_PATH, 'val.txt'), 'w')
    # walk the XML annotation folder
    for root, dirs, files in os.walk(_XML_FILE_PATH):
        i = 1
        j = 1
        for file in files:
            name = os.path.splitext(file)[0]  # file name without the .xml extension
            if not (i % 5):  # every 5th file goes to the test set; change 5 to adjust the ratio
                f_test.write(name + '\n')
            else:  # training + validation set
                f_trainval.write(name + '\n')
                if j % 2:  # training set; change 2 to adjust the ratio
                    f_train.write(name + '\n')
                else:  # validation set
                    f_val.write(name + '\n')
                j += 1
            i += 1
    f_test.close()
    f_train.close()
    f_trainval.close()
    f_val.close()
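The split rule in the script above is easier to reason about as a pure function: every fifth file goes to test.txt (about 20%), and the remaining files alternate between train.txt and val.txt (about 40% each). The function below is a hypothetical refactoring that reproduces that rule without touching the file system.

```python
import os


def split_voc_ids(xml_files):
    """Apply the script's split rule to a list of .xml file names."""
    splits = {'test': [], 'trainval': [], 'train': [], 'val': []}
    j = 1  # counts only the files that end up in trainval
    for i, name in enumerate(xml_files, start=1):
        stem = os.path.splitext(name)[0]
        if i % 5 == 0:
            splits['test'].append(stem)
        else:
            splits['trainval'].append(stem)
            if j % 2:
                splits['train'].append(stem)
            else:
                splits['val'].append(stem)
            j += 1
    return splits
```

Because os.walk lists files in arbitrary order, passing sorted(files) makes the split reproducible between runs.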
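(3) Some .xml files contain an extra descriptive line.
This line has to be removed from the .xml files. The script below does this; it is placed in /home/pc/abc/research/object_detection/ROS1/1/.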
#coding=utf-8
import os
import os.path
import xml.dom.minidom
path = "/home/pc/abc/ROS/Annotation"
files = os.listdir(path)
s = []
def file_extension(path):
    return os.path.splitext(path)[1]
for xmlFile in files:
    # Use the full path so the script works from any working directory.
    xmlPath = os.path.join(path, xmlFile)
    if not os.path.isdir(xmlPath):
        if file_extension(xmlFile) == '.xml':
            print(xmlFile)
            with open(xmlPath, "r") as f:
                lines = f.readlines()
            with open(xmlPath, "w") as f_w:
                for line in lines:
                    if "
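The condition in the last line of the script above is cut off, so the exact line it removes is not known. A common culprit, and the assumption made here, is the XML declaration (for example `<?xml version="1.0" encoding="utf-8"?>`): create_pascal_tf_record.py reads each annotation as text and passes it to lxml's etree.fromstring, which rejects str input that carries an encoding declaration with a ValueError. A sketch of the clean-up under that assumption (strip_xml_declaration is a hypothetical name):

```python
import os


def strip_xml_declaration(xml_dir):
    """Remove lines starting with '<?xml' from every .xml file in xml_dir.

    Assumption: the descriptive line to delete is the XML declaration.
    Returns the names of the files that were changed.
    """
    changed = []
    for name in sorted(os.listdir(xml_dir)):
        path = os.path.join(xml_dir, name)
        if not name.endswith('.xml') or os.path.isdir(path):
            continue
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        kept = [line for line in lines if not line.lstrip().startswith('<?xml')]
        if len(kept) != len(lines):  # rewrite only files that actually changed
            with open(path, 'w', encoding='utf-8') as f:
                f.writelines(kept)
            changed.append(name)
    return changed
```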
(4) Adjust pascal_label_map.pbtxt in /object_detection/ROS1 to match the number and names of your classes.
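The label map is a small text-format protobuf with one item per class. The generator below is a sketch: the class names are placeholders, ids start at 1 because id 0 is reserved for the background class, and each name must match the `<name>` tag in the .xml files exactly, since the conversion script looks names up in label_map_dict.

```python
def make_label_map(class_names):
    """Return pascal_label_map.pbtxt content for the given class names."""
    items = []
    for idx, name in enumerate(class_names, start=1):  # id 0 = background
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return '\n'.join(items)


if __name__ == '__main__':
    # Placeholder names: replace them with the <name> values from your .xml files.
    print(make_label_map(['class_a', 'class_b']))
```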
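(5) The Object Detection API reads its input data only in .tfrecord format, so the Pascal VOC files above must be packed into that format before they are fed to the network. This is done with create_pascal_tf_record.py in ~/research/object_detection/; as mentioned earlier, change the years setting in it (two places). In addition, some datasets are labeled carelessly, with boxes extending beyond the image border, which causes errors, so the box coordinates need to be clipped. The adjusted program is as follows: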
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert raw PASCAL dataset to TFRecord for object_detection.
Example usage:
./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \
--year=VOC2012 \
--output_path=/home/user/pascal.record
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import logging
import os
from lxml import etree
import PIL.Image
import tensorflow as tf
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'ROS1/pascal_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
FLAGS = flags.FLAGS
SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged', 'RSDS2016']
def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  """Convert XML derived dict to tf.Example proto.
  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.
  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integers ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.
  Returns:
    example: The converted tf.Example.
  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """
  img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
  full_path = os.path.join(dataset_directory, img_path)
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()
  width = int(data['size']['width'])
  height = int(data['size']['height'])
  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  for obj in data['object']:
    difficult = bool(int(obj['difficult']))
    if ignore_difficult_instances and difficult:
      continue
    difficult_obj.append(int(difficult))
    bndbox = obj['bndbox']
    # Tolerate boxes that spill over the image border: clip the normalized
    # xmin/ymin at 0 and xmax/ymax at 1.
    xmin.append(max(float(bndbox['xmin']) / width, 0.0))
    ymin.append(max(float(bndbox['ymin']) / height, 0.0))
    xmax.append(min(float(bndbox['xmax']) / width, 1.0))
    ymax.append(min(float(bndbox['ymax']) / height, 1.0))
    classes_text.append(obj['name'].encode('utf8'))
    classes.append(label_map_dict[obj['name']])
    truncated.append(int(obj['truncated']))
    poses.append(obj['pose'].encode('utf8'))
  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example
def main(_):
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  if FLAGS.year not in YEARS:
    raise ValueError('year must be in : {}'.format(YEARS))
  data_dir = FLAGS.data_dir
  years = ['VOC2007', 'VOC2012', 'RSDS2016']
  if FLAGS.year != 'merged':
    years = [FLAGS.year]
  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
  for year in years:
    logging.info('Reading from PASCAL %s dataset.', year)
    examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                 FLAGS.set + '.txt')
    annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
    examples_list = dataset_util.read_examples_list(examples_path)
    for idx, example in enumerate(examples_list):
      if idx % 100 == 0:
        logging.info('On image %d of %d', idx, len(examples_list))
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
      tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                      FLAGS.ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())
  writer.close()


if __name__ == '__main__':
  tf.app.run()
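The tolerance added to dict_to_tf_example boils down to normalizing each pixel coordinate by the image size and clipping it at the border. The standalone helper below (hypothetical name normalize_box) mirrors exactly what the modified lines do: xmin and ymin are clipped below at 0, and xmax and ymax above at 1.

```python
def normalize_box(xmin, ymin, xmax, ymax, width, height):
    """Normalize a pixel box to image-relative coordinates, clipping at the border."""
    width, height = float(width), float(height)
    return (max(xmin / width, 0.0),   # lower bounds clipped at 0
            max(ymin / height, 0.0),
            min(xmax / width, 1.0),   # upper bounds clipped at 1
            min(ymax / height, 1.0))


# A box that overruns the left and right edges of a 200x100 image.
print(normalize_box(-5, 10, 210, 50, 200, 100))  # (0.0, 0.1, 1.0, 0.5)
```

It does not repair boxes whose xmin exceeds xmax; such annotations still need to be fixed by hand.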
(6) Run the following two commands in a terminal:
python3 create_pascal_tf_record1.py --data_dir ROS1/VOCdevkit/ --year=RSDS2016 --set=train --output_path=ROS1/pascal_train.record
python3 create_pascal_tf_record1.py --data_dir ROS1/VOCdevkit/ --year=RSDS2016 --set=val --output_path=ROS1/pascal_val.record
This produces two files, pascal_train.record and pascal_val.record, in ~/research/object_detection/ROS1/.
(7) Data preparation is now complete.
7. Download an existing model. This article uses Faster R-CNN + Inception_Resnet_V2 trained on the COCO dataset, available at http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017.tar.gz
After downloading and extracting it, you get five files; put them in ~/research/object_detection/ROS1/pretrained/.
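The archive can also be unpacked from Python with the standard tarfile module. A sketch, assuming the .tar.gz has already been downloaded; model-zoo archives usually unpack into a top-level folder named after the model, so move the files up into pretrained/ if you want them directly there as described above. extract_pretrained is a hypothetical helper.

```python
import os
import tarfile


def extract_pretrained(tar_path, dest_dir):
    """Extract the regular files of a .tar.gz into dest_dir; return their archive names."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(tar_path, 'r:gz') as tar:
        members = [m for m in tar.getmembers() if m.isfile()]
        # Use the safer 'data' extraction filter where this Python version provides it.
        extra = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
        tar.extractall(dest_dir, members=members, **extra)
    return sorted(m.name for m in members)
```

For example: extract_pretrained('faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017.tar.gz', 'ROS1/pretrained').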
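8. In ~/research/object_detection/samples/configs/, find the .config file for the model, copy it to ~/research/object_detection/ROS1/, rename it ROS1.config, and edit the following settings:
(1) num_classes: the number of classes in the dataset being trained;
(2) num_examples in eval_config: the size of the validation set, i.e. the number of images (lines) in val.txt;
(3) the five paths (the PATH_TO_BE_CONFIGURED placeholders): replace them with the corresponding paths under ROS1;
(4) dropout_keep_probability: the dropout keep probability, used to reduce overfitting;
(5) num_steps: the total number of training steps;
(6) learning_rate: the learning rate;
(7) min_dimension, max_dimension: preprocessing settings giving the minimum and maximum image size after resizing (set them carefully when memory is limited).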
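Items (1) to (3) can also be scripted. A sketch, assuming the sample config writes its paths as "PATH_TO_BE_CONFIGURED/<file name>" and its counts as num_classes: N and num_examples: N; the helper names are hypothetical, the path mapping you pass in must be built from your own files, and the remaining hyperparameters are still best edited by hand.

```python
import re

_PLACEHOLDER = re.compile(r'"PATH_TO_BE_CONFIGURED/([^"]+)"')


def count_examples(list_file):
    """Count non-empty lines in an ImageSets/Main/*.txt file (e.g. val.txt)."""
    with open(list_file) as f:
        return sum(1 for line in f if line.strip())


def list_placeholders(config_text):
    """Return the file names that still sit behind PATH_TO_BE_CONFIGURED."""
    return _PLACEHOLDER.findall(config_text)


def patch_config(config_text, num_classes, num_examples, path_map):
    """Set num_classes / num_examples and replace the mapped placeholder paths.

    path_map maps a placeholder file name (as returned by list_placeholders)
    to its new path; names not in path_map are left untouched.
    """
    text = re.sub(r'num_classes: \d+', 'num_classes: %d' % num_classes, config_text)
    text = re.sub(r'num_examples: \d+', 'num_examples: %d' % num_examples, text)

    def replace(match):
        name = match.group(1)
        return '"%s"' % path_map[name] if name in path_map else match.group(0)

    return _PLACEHOLDER.sub(replace, text)
```

A typical flow is to read ROS1/ROS1.config, print list_placeholders(text) to see which names need a mapping, call patch_config with count_examples('.../ImageSets/Main/val.txt') as num_examples, and write the result back.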
9. Empty the ~/research/object_detection/ROS1/train_dir/ directory, which will hold the training output.
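Emptying train_dir can be done with shutil. A minimal sketch (reset_dir is a hypothetical name): starting from an empty directory keeps checkpoints and event files from an earlier run from being mixed with the new one.

```python
import os
import shutil


def reset_dir(path):
    """Delete path if it exists, then recreate it as an empty directory."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path)
```

For example: reset_dir('ROS1/train_dir'), run from ~/research/object_detection/.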
10. Start training by entering the following command in a terminal:
python3 train.py --train_dir ROS1/train_dir/ --pipeline_config_path ROS1/ROS1.config
This starts the training process.
11. Export the model. The trained checkpoints are saved in ~/research/object_detection/ROS1/train_dir/.
In a terminal, enter:
python3 export_inference_graph.py --input_type image_tensor --pipeline_config_path ROS1/ROS1.config --trained_checkpoint_prefix ROS1/train_dir/model.ckpt-327 --output_directory ROS1/export/
Running export_inference_graph.py exports the checkpoint as a model, saved as frozen_inference_graph.pb in ~/research/object_detection/ROS1/export/.
The other files in that directory can be deleted.
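The number in model.ckpt-327 is the training step of the checkpoint being exported, so it differs from run to run. A small hypothetical helper that finds the highest-step checkpoint prefix in train_dir, assuming the usual model.ckpt-<step>.index file naming, is sketched below. In TF 1.x, tf.train.latest_checkpoint(train_dir), which reads the directory's checkpoint file, is an alternative.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-<step> prefix with the highest step, or None."""
    steps = []
    for name in os.listdir(train_dir):
        match = re.match(r'model\.ckpt-(\d+)\.index$', name)
        if match:
            steps.append(int(match.group(1)))  # compare numerically, not as strings
    if not steps:
        return None
    return os.path.join(train_dir, 'model.ckpt-%d' % max(steps))
```

The returned string can be passed directly as --trained_checkpoint_prefix.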
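12. Prepare to run the newly trained model. Change to ~/research/object_detection/, start jupyter notebook, and open retain_object_detection_tutorial.ipynb.
In the notebook, update the paths, comment out the code that downloads the model, and put the images to be detected in ~/research/object_detection/test_images/, renaming them in the form "image0.jpg".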
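Renaming the test images by hand gets tedious. The sketch below renames every .jpg in the folder, in sorted order, to image<N>.jpg starting from a chosen index (rename_test_images is a hypothetical helper). It assumes the notebook builds its list of test image paths from a numeric range, so make sure that range covers the indices produced here.

```python
import os


def rename_test_images(image_dir, start=0):
    """Rename every .jpg in image_dir (sorted by name) to image<start>.jpg, ..."""
    names = sorted(f for f in os.listdir(image_dir) if f.lower().endswith('.jpg'))
    # First move everything to temporary names so that a file already called,
    # say, image1.jpg is not overwritten before it has been renamed itself.
    temps = []
    for k, name in enumerate(names):
        temp = os.path.join(image_dir, '.renaming_%d' % k)
        os.rename(os.path.join(image_dir, name), temp)
        temps.append(temp)
    new_names = []
    for k, temp in enumerate(temps):
        new_name = 'image%d.jpg' % (start + k)
        os.rename(temp, os.path.join(image_dir, new_name))
        new_names.append(new_name)
    return new_names
```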
13. Run the new model in jupyter notebook to perform detection. The results are displayed immediately and saved as .png files in ~/research/object_detection/test_images/.