使用Object Detection API训练

使用Object Detection API训练

准备工作

Running Locally中提到准备工作大致有三个:

  1. 安装Tensorflow Object Detection API
  2. 数据集
  3. Object Detection pipeline设置文件

安装Tensorflow Object Detection API

没啥说的,看官网教程Installation

数据集

可以按照Preparing Inputs来准备TFRecord格式的数据集。

当然也可以使用models/research/object_detection/dataset_tools/下的脚本将常见的数据集创建成TFRecord格式的。其中常用的有create_pascal_tf_record.py,就是将安装PASCAL VOC组织的数据转换为TFRecord格式。在使用时,代码其中的’aeroplane_’ + 多余,删去即可。

Object Detection pipeline设置文件

将object_detection/samples/configs/对应的config文件拷贝一份,然后根据实际情况修改。

  1. num_classes:修改为自己的classes num
  2. 将所有PATH_TO_BE_CONFIGURED的地方修改为自己之前设置的路径(共5处)
  3. batch_size根据情况修改,初始设置可能会导致内存不够用。

训练

legacy

train.py和eval.py被移到legacy文件下了。

python object_detection/legacy/train.py \
    --logtostderr \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --train_dir=${TRAIN_DIR}
python object_detection/legacy/eval.py \
    --logtostderr \
    --checkpoint_dir=${TRAIN_DIR} \
    --eval_dir=${EVAL_DIR} \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH}

如果报关于unicode的错,将object_detection\utils\object_detection_evaluation.py下的category_name = unicode(category_name, ‘utf-8’)修改为category_name = str(category_name)

recommend

model_main.py将train和eval结合在一块,官方推荐使用。

python object_detection/legacy/eval.py \
    --logtostderr \
    --checkpoint_dir=${TRAIN_DIR} \
    --eval_dir=${EVAL_DIR} \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH}
  • 添加 tf.logging.set_verbosity(tf.logging.INFO) 到model_main.py 的 import 区域之后,会每隔一百个step输出loss,总比没有好,至少它让你知道它在跑。
  • 如果是python3训练,添加list() 到 model_lib.py的大概390行 category_index.values()变成: list(category_index.values()),否则会有 can’t pickle dict_values ERROR出现
  • 还有一个问题是,用model_main.py 训练时,因为它把老版本的train.py和eval.py集合到了一起,所以制定eval num时指定不好会有warning出现,就像:

导出模型

export INPUT_TYPE=image_tensor
python object_detection/export_inference_graph.py \
    --input_type=${INPUT_TYPE} \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --trained_checkpoint_prefix=${TRAIN_DIR}/`head -n 1  ${TRAIN_DIR}/checkpoint | grep -o -E '\".+\"' | sed s/\"//g` \
    --output_directory=${EXPORT_DIR}

参考

TensorFlow object detection API应用
第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)

你可能感兴趣的:(记录)