All the official steps for the object detection module, from installation to implementation, are documented in ./models-master/object_detection/g3doc.
First you need your annotated data (the xml files) plus the train/val split files (train.txt, val.txt); then you use
./models-master/object_detection/create_pascal_tf_record.py to convert the data into TFRecord-format (.record) train and val files.
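Each line of train.txt / val.txt is simply the basename of one example (no extension), matching both the xml in Annotations and the jpg in JPEGImages. If you want to generate the split automatically, here is a minimal sketch I might use (my own helper, not part of the official repo; the category path is just an example from my layout):

import os
import random

def make_splits(category_dir, val_ratio=0.2):
    # list every annotated example by stripping '.xml' from the Annotations folder
    ann_dir = os.path.join(category_dir, 'Annotations')
    main_dir = os.path.join(category_dir, 'ImageSets', 'Main')
    if not os.path.isdir(main_dir):
        os.makedirs(main_dir)
    names = [os.path.splitext(f)[0] for f in os.listdir(ann_dir) if f.endswith('.xml')]
    random.shuffle(names)
    n_val = int(len(names) * val_ratio)
    # write one basename per line, the format dataset_util.read_examples_list expects
    for set_name, ids in [('val', names[:n_val]), ('train', names[n_val:])]:
        with open(os.path.join(main_dir, set_name + '.txt'), 'w') as f:
            f.write('\n'.join(ids) + '\n')

make_splits('/home/saners/Mobilenet/makeTest/category_aircraft')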
My command lines were:
python /home/saners/Mobilenet/makeTest/createtf.py --label_map_path=/home/saners/Mobilenet/makeTest/own_label_map.pbtxt --data_dir=/home/saners/Mobilenet/makeTest --set=train --output_path=/home/saners/Mobilenet/makeTest/train.record
python /home/saners/Mobilenet/makeTest/createtf.py --label_map_path=/home/saners/Mobilenet/makeTest/own_label_map.pbtxt --data_dir=/home/saners/Mobilenet/makeTest --set=val --output_path=/home/saners/Mobilenet/makeTest/val.record
#1 My createtf.py is just create_pascal_tf_record.py. Because I used my own data, some of the paths in it did not fit, so I made a few small changes. There is really not much to them, but here is the whole file anyway:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import os

from lxml import etree
import PIL.Image
import tensorflow as tf

import sys
sys.path.append('/home/saners/Mobilenet/models-master')  # the object_detection package was not on my path, so I add it manually
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']
# YEARS = ['VOC2007', 'VOC2012', 'merged']  # commented out: I use my own data, and these strings
# are only the folder names of the officially downloaded dataset, used to build paths


def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
    # cla is a global variable set in main(): the name of the current category folder,
    # which I need to build the image path under data_dir
    img_path = os.path.join(cla, image_subdirectory, data['filename'])
    full_path = os.path.join(dataset_directory, img_path)
    with tf.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = PIL.Image.open(encoded_jpg_io)
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []
    label_name = cla.split('_')[1]  # class name parsed from the folder name, e.g. 'category_aircraft' -> 'aircraft' (not used below)
    for obj in data['object']:
        difficult = bool(int(obj['difficult']))
        if ignore_difficult_instances and difficult:
            continue
        difficult_obj.append(int(difficult))
        # box coordinates are stored normalized to [0, 1]
        xmin.append(float(obj['bndbox']['xmin']) / width)
        ymin.append(float(obj['bndbox']['ymin']) / height)
        xmax.append(float(obj['bndbox']['xmax']) / width)
        ymax.append(float(obj['bndbox']['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(label_map_dict[obj['name']])
        truncated.append(int(obj['truncated']))
        poses.append(obj['pose'].encode('utf8'))

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }))
    return example


def main(_):
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))
    # if FLAGS.year not in YEARS:  # commented out, matching the YEARS list above
    #     raise ValueError('year must be in : {}'.format(YEARS))
    data_dir = FLAGS.data_dir
    # years = ['VOC2007', 'VOC2012']  # same as above
    # if FLAGS.year != 'merged':
    #     years = [FLAGS.year]
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    class_file = os.listdir(data_dir)  # every entry in data_dir is treated as one category folder
    global cla  # declare the global variable cla used in dict_to_tf_example
    for cla in class_file:
        examples_path = os.path.join(data_dir, cla, 'ImageSets', 'Main',
                                     FLAGS.set + '.txt')  # where train.txt / val.txt live
        annotations_dir = os.path.join(data_dir, cla, FLAGS.annotations_dir)  # where the xml files live
        if os.path.exists(examples_path):
            examples_list = dataset_util.read_examples_list(examples_path)
        else:
            continue  # skip entries (e.g. the label map file) that are not category folders
        for idx, example in enumerate(examples_list):
            if idx % 100 == 0:
                logging.info('On image %d of %d', idx, len(examples_list))
            path = os.path.join(annotations_dir, example + '.xml')
            with tf.gfile.GFile(path, 'r') as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
            tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                            FLAGS.ignore_difficult_instances)
            writer.write(tf_example.SerializeToString())
    writer.close()


if __name__ == '__main__':
    tf.app.run()
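As a quick sanity check after each run, you can count how many examples actually landed in the output file. This is my own little snippet, not part of the official tooling, and the path is just my layout:

import tensorflow as tf

# iterate once over the TFRecord and count the serialized examples
count = sum(1 for _ in tf.python_io.tf_record_iterator('/home/saners/Mobilenet/makeTest/train.record'))
print('examples in train.record:', count)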
#2 own_label_map.pbtxt: for this file I took ./models-master/object_detection/data/pascal_label_map.pbtxt and changed the classes to my own. I only set up one class, so the whole file is just:
item {
  id: 1
  name: 'aircraft'
}
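If you have more classes, as far as I understand you just append more item blocks with consecutive ids starting from 1 (id 0 is reserved for the background); the second class name below is only a placeholder:

item {
  id: 1
  name: 'aircraft'
}
item {
  id: 2
  name: 'ship'
}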
#3 --data_dir is the root directory where your image data lives.
#4 --set says whether the data being converted is train or val.
#5 --output_path is where the .record file gets written; name it train.record or val.record so the two are easy to tell apart.
My data directory structure:
+dataTest                    # data root directory
  +category_aircraft         # I only have one class; this is the class root folder (keep the folder names inside it exactly as shown)
    +Annotations             # xml files for all images
    +ImageSets               # holds the split files
      +Main                  # train.txt and val.txt live here; this extra level exists only because the official conversion script hard-codes this path string (the same reason the other folder names must match); you can change it, but then change the source too
        train.txt
        val.txt
    +JPEGImages              # the images; they must be JPEG format, i.e. .jpg
  own_label_map.pbtxt        # the label map file
My training directory structure:
+example
  +data
    own_label_map.pbtxt
    train.record             # training data, generated above
    val.record               # validation data
  +models
    +model                   # where the training results go
    ssd_mobilenet_v1_pets.config  # the model config file; it lives in ./models-master/object_detection/samples/configs and needs editing (the file itself tells you what to configure when you open it)
Roughly: change num_classes to your number of classes. I commented out the line fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" because I am training a model from scratch, not starting from an existing checkpoint. from_detection_checkpoint: true means the checkpoint comes from a detection model, false means it comes from a classification model; the official docs say: `from_detection_checkpoint` is a boolean value. If false, it assumes the checkpoint was from an object classification checkpoint. Note that starting from a detection checkpoint will usually result in a faster training job than a classification checkpoint. (If anyone understands this better than I do, please confirm.)
All the remaining edits in the file are just the corresponding file paths, so I won't repeat them here.
num_steps: 200000 is the number of training iterations and num_examples: 2000 is the number of val examples; I got the explanation of the second parameter from another blog and can't guarantee it's right, so study it yourself!
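For reference, here is a rough sketch of those edits as I understand them, with the unchanged parts of ssd_mobilenet_v1_pets.config abbreviated to ... and the paths taken from my own layout above, so adapt them:

model {
  ssd {
    num_classes: 1   # one class: 'aircraft'
    ...
  }
}
train_config: {
  ...
  # fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"  # commented out: training from scratch
  from_detection_checkpoint: true
  num_steps: 200000  # training iterations
}
train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/saners/Mobilenet/exampleTest2/data/train.record"
  }
  label_map_path: "/home/saners/Mobilenet/exampleTest2/data/own_label_map.pbtxt"
}
eval_config: {
  num_examples: 2000  # reportedly the number of val examples
}
eval_input_reader: {
  tf_record_input_reader {
    input_path: "/home/saners/Mobilenet/exampleTest2/data/val.record"
  }
  label_map_path: "/home/saners/Mobilenet/exampleTest2/data/own_label_map.pbtxt"
}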
Once all of the above is done you can start training. My command line:
python /home/saners/Mobilenet/models-master/object_detection/train.py \
--logtostderr \
--pipeline_config_path=/home/saners/Mobilenet/exampleTest2/models/ssd_mobilenet_v1_pets.config \
--train_dir=/home/saners/Mobilenet/exampleTest2/models/model/
I'm still working through the training itself and will update this post as I go.