2017年Google公司开放了TensorFlow Object Detection API,该项目实现了多种深度学习框架,包括mask RCNN。利用API根据github指引可以轻松实现语义分割,实例分割等任务。最近有个细胞图像识别的任务,刚好研究一下mask RCNN模型。
首先从github上安装好TensorFlow Object Detection API,下载预训练模型。
git clone --depth 1 https://github.com/tensorflow/models
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .
将Slim加入PYTHONPATH,在research文件夹下,执行一下命令。命令会自动检查API是否正确安装,Tensorflow 2以上版本使用model_builder_tf1_test.py。
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
PYTHONPATH="$PYTHONPATH:models/"
python object_detection/builders/model_builder_tf1_test.py
准备数据
使用labelme软件标注图像数据,安装及使用详细点击链接,github。TensorFlow数据集输入的格式需要.TFRecord格式,要将数据标注的json格式转为tfrecord。https://blog.csdn.net/WellTung_666/article/details/105723640
训练新的模型
TensorFlow Object Detection API训练需要修改一个名为*.config的配置文件,在object_detection/samples/configs/文件夹中有各种模型设置的示例。先将对应模型的.config文件复制到自己的目录,修改其中以下几个地方:
- num_classes,分类物体的类别数。
- num_examples,验证阶段需要执行的图片数量
- fine_tune_checkpoint,预训练模型,mask_rcnn_resnet101_atrous_coco_2018_01_28/model.ckpt
- input_path,两处input_path需要修改,训练和验证数据的tfrecord文件,多个文件可以用 "*" 指定:tfrecord/train-*.tfrecord。
-
label_map_path,指定分类的pbtxt文件。
接下来就可以训练模型了,新建一个training文件夹用于保存模型。
python object_detection/model_main.py
--pipeline_config_path=mask_rcnn_resnet101_atrous_coco.config
--model_dir=training
导出模型并测试单张图片
TensorFlow Object Detection API提供了一个export_inference_graph.py脚本,用来导出模型,model.ckpt-813是指第813次迭代保存的模型,需要跟据训练情况选择合适模型。
python object_detection/export_inference_graph.py
--pipeline_config_path mask_rcnn_resnet101_atrous_coco.config
--trained_checkpoint_prefix training/model.ckpt-813
--output_directory exported_model
参考github源代码可视化单张图测试结果。
python model_test.py