Mindspore版本的Mask R-CNN,在GPU上训练。其backbone包括Resnet50和Mobilenetv1。
Mask R-CNN扩展了Faster R-CNN,增加了一个用于预测每个感兴趣区域(Region of Interest, RoI)上的分割掩码的分支,与现有的用于分类和边界框回归的分支并行。掩码分支是一个小的FCN应用于每个RoI,以像素-像素的方式预测一个分割掩码。Mask R-CNN是简单的实现和训练的更快的R-CNN框架,这有利于广泛的灵活的架构设计。此外,掩码分支只增加了很小的计算开销,从而实现了快速系统和快速实验。
没有bells和whistles的情况下,在COCO的实例分割任务上,Mask R-CNN超越了之前所有SOTA的单个模型取得的结果,包括根据2016年实例检测比赛冠军模型精心设计的,在2018年推出的改进作品。
model | training dataset | bbox | segm | ckpt |
---|---|---|---|---|
maskrcnn_coco2017_bbox37.4_segm32.9 | coco2017 | 0.374 | 0.329 | checkpoint/maskrcnn_coco2017_acc32.9.ckpt |
maskrcnnmobilenetv1_coco2017_bbox22.2_segm15.8 | coco2017 | 0.222 | 0.158 | checkpoint/maskrcnnmobilenetv1_coco2017_bbox24.00_segm21.5.ckpt |
请在此处下载 预训练权重文件。并将它放在checkpoint文件夹下。
在这里,我们列出了一些重要的训练参数。此外,您可以查看配置文件的详细信息。
Parameter | Default | Description |
---|---|---|
workers | 1 | Number of parallel workers |
device_target | GPU | Device type |
learning_rate | 0.002 | learning rate |
weight_decay | 1e-4 | Control weight decay speed |
total_epoch | 13 | Number of epoch |
batch_size | 2 | Batch size |
dataset | coco | Dataset name |
pre_trained | ./checkpoint | The path of pretrained model |
checkpoint_path | ./ckpt_0 | The path to save |
下面介绍一下如何使用Mask R-CNN和Mask R-CNN MobileNetV1。
以Mask R-CNN为例,Mask R-CNN MobileNetV1遵循相同的流程。
首先,你需要下载 coco2017 数据集。
也可以直接用下方命令,在服务器端获取coco数据集。
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
然后您可以将它们解压缩到指定的文件夹。
下载完coco2017后,需要修改配置文件中的“coco_root”和其他参数。
获得数据集后,确保你的路径如下:
.
└─cocodataset
├─annotations
├─instance_train2017.json
└─instance_val2017.json
├─val2017
└─train2017
如果您想使用其他数据集,请在运行脚本时将config.py中的“dataset”更改为“other”(自己定义的命名)。
在你开始训练模型之前。数据增强对于您的数据集以及创建训练数据和测试数据是必要的。对于coco数据集,你可以使用dataset.py为图像添加蒙版,并将它们转换到MindRecord。MindRecord是一种指定的数据格式,可以在某些场景下优化MindSpore的性能。
当您应该更改参数和数据集时,您可以在config.py中更改参数和数据集路径。
python train.py
训练之后,您可以使用验证集来评估模型的性能。
运行eval.py来实现这一点。model_type参数的使用与训练过程相同。
python eval.py
Evaluate annotation type bbox
IoU | area | maxDets | Average Precision (AP) |
---|---|---|---|
0.50:0.95 | all | 100 | 0.374 |
0.50 | all | 100 | 0.599 |
0.75 | all | 100 | 0.403 |
0.50:0.95 | small | 100 | 0.235 |
0.50:0.95 | medium | 100 | 0.415 |
0.50:0.95 | large | 100 | 0.474 |
0.50:0.95 | all | 1 | 0.312 |
0.50:0.95 | all | 10 | 0.501 |
0.50:0.95 | all | 100 | 0.530 |
0.50:0.95 | small | 100 | 0.363 |
0.50:0.95 | medium | 100 | 0.571 |
0.50:0.95 | large | 100 | 0.656 |
Evaluate annotation type segm
IoU | area | maxDets | Average Precision (AP) |
---|---|---|---|
0.50:0.95 | all | 100 | 0.329 |
0.50 | all | 100 | 0.555 |
0.75 | all | 100 | 0.344 |
0.50:0.95 | small | 100 | 0.165 |
0.50:0.95 | medium | 100 | 0.357 |
0.50:0.95 | large | 100 | 0.477 |
0.50:0.95 | all | 1 | 0.284 |
0.50:0.95 | all | 10 | 0.436 |
0.50:0.95 | all | 100 | 0.455 |
0.50:0.95 | small | 100 | 0.283 |
0.50:0.95 | medium | 100 | 0.490 |
0.50:0.95 | large | 100 | 0.592 |
Evaluate annotation type bbox
IoU | area | maxDets | Average Precision (AP) |
---|---|---|---|
0.50:0.95 | all | 100 | 0.235 |
0.50 | all | 100 | 0.409 |
0.75 | all | 100 | 0.241 |
0.50:0.95 | small | 100 | 0.145 |
0.50:0.95 | medium | 100 | 0.250 |
0.50:0.95 | large | 100 | 0.296 |
0.50:0.95 | all | 1 | 0.243 |
0.50:0.95 | all | 10 | 0.397 |
0.50:0.95 | all | 100 | 0.418 |
0.50:0.95 | small | 100 | 0.264 |
0.50:0.95 | medium | 100 | 0.449 |
0.50:0.95 | large | 100 | 0.515 |
Evaluate annotation type segm
IoU | area | maxDets | Average Precision (AP) |
---|---|---|---|
0.50:0.95 | all | 100 | 0.191 |
0.50 | all | 100 | 0.350 |
0.75 | all | 100 | 0.190 |
0.50:0.95 | small | 100 | 0.095 |
0.50:0.95 | medium | 100 | 0.204 |
0.50:0.95 | large | 100 | 0.278 |
0.50:0.95 | all | 1 | 0.206 |
0.50:0.95 | all | 10 | 0.315 |
0.50:0.95 | all | 100 | 0.328 |
0.50:0.95 | small | 100 | 0.194 |
0.50:0.95 | medium | 100 | 0.350 |
0.50:0.95 | large | 100 | 0.424 |
最后,您可以使用自己的图像来测试训练后的模型。将图像放入images文件夹,然后运行inference. py进行推断。
python infer.py
[1] He K, Gkioxari G, Dollár P, et al. Mask r-cnn[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2961-2969.