【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集

Tensorflow object detection的安装请参考链接【Tensorflow】安装tensorflow object detection API。

1. 下载VOC数据集

到官网下载VOC数据集。数据集的目录结构如下:

【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集_第1张图片

2. 制作tfrecord

在models/research/object_detection/dataset_tools下有一个create_pascal_tf_record.py脚本,运行这个脚本可以直接将VOC数据集转换成tfrecord格式的数据。

python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=train --output_path=./data/pascal_train.record
python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=val --output_path=./data/pascal_val.record

在data目录下生成了两个record文件。

【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集_第2张图片【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集_第3张图片

3. 下载预训练权重

下载地址

4. 修改config文件

在models/research/object_detection/samples/configs/目录下将ssd_mobilenet_v2_coco.config复制一份重命名为ssd_mobilenet_v2_pascal.config,修改以下几个地方:

第9行

num_classes: 90

修改为

num_classes: 20

156行

fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"

替换成自己的路径

fine_tune_checkpoint: "/home/models/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"

第173行

train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}

修改为

train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/models/research/object_detection/data/pascal_train.record"
  }
  label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
}

修改182行num_examples你验证集的图像数量。

第187行

​
val_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
  shuffle: false
  num_readers:1
}​

修改为

​
​
val_input_reader: {
  tf_record_input_reader {
    input_path: "/home/models/research/object_detection/data/pascal_train.record"
  }
  label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
  shuffle: false
  num_readers:1
}​

​

5. 训练

python object_detection/model_main.py --logtostderr --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --model_dir=/home/data/VOCdekit/ssd_mobilnet_v2 --num_train_steps=50000 --num_eval_steps=500

【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集_第4张图片

6. 训练可视化

新版的代码在训练时terminal可能会卡住没有输出,不过没关系,可以在tensorboard中查看训练情况。

tensorboard --logdir=/home/data/VOCdekit/ssd_mobilnet_v2

把终端输出的http://xxxxxx复制到浏览器中打开

【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集_第5张图片

7. 固化权重

python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --trained_checkpoint_prefix=/home/data/VOCdekit/ssd_mobilnet_v2/model.ckpt-50000 --output_directory=/home/data/VOCdekit/ssd_mobilnet_v2_pascal

生成如下文件

你可能感兴趣的:(深度学习,tensorflow,Object,Detection)