Tensorflow object detection的安装请参考链接【Tensorflow】安装tensorflow object detection API。
1. 下载VOC数据集
到官网下载VOC数据集。数据集的目录结构如下:
2. 制作tfrecord
在models/research/object_detection/dataset_tools下有一个create_pascal_tf_record.py脚本,运行这个脚本可以直接将VOC数据集转换成tfrecord格式的数据。
python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=train --output_path=./data/pascal_train.record
python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=val --output_path=./data/pascal_val.record
在data目录下生成了两个record文件。
3. 下载预训练权重
下载地址
4. 修改config文件
在models/research/object_detection/samples/configs/目录下将ssd_mobilenet_v2_coco.config复制一份重命名为ssd_mobilenet_v2_pascal.config,修改以下几个地方:
第9行
num_classes: 90
修改为
num_classes: 20
156行
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
替换成自己的路径
fine_tune_checkpoint: "/home/models/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
第173行
train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
修改为
train_input_reader: {
tf_record_input_reader {
input_path: "/home/models/research/object_detection/data/pascal_train.record"
}
label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
}
修改182行num_examples你验证集的图像数量。
第187行
val_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
shuffle: false
num_readers:1
}
修改为
val_input_reader: {
tf_record_input_reader {
input_path: "/home/models/research/object_detection/data/pascal_train.record"
}
label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
shuffle: false
num_readers:1
}
5. 训练
python object_detection/model_main.py --logtostderr --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --model_dir=/home/data/VOCdekit/ssd_mobilnet_v2 --num_train_steps=50000 --num_eval_steps=500
6. 训练可视化
新版的代码在训练时terminal可能会卡住没有输出,不过没关系,可以在tensorboard中查看训练情况。
tensorboard --logdir=/home/data/VOCdekit/ssd_mobilnet_v2
把终端输出的http://xxxxxx复制到浏览器中打开
7. 固化权重
python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --trained_checkpoint_prefix=/home/data/VOCdekit/ssd_mobilnet_v2/model.ckpt-50000 --output_directory=/home/data/VOCdekit/ssd_mobilnet_v2_pascal
生成如下文件