Training your own dataset with the PyTorch maskrcnn-benchmark framework when the number of classes differs

Test environment:
Ubuntu 16.04
conda, Python 3.6
CUDA 9.0

To run the demo, just follow the steps in the repository README:
https://github.com/facebookresearch/maskrcnn-benchmark

The guides available online for training on your own dataset are incomplete; I ran into a few problems myself, so I record the solutions here.

  • Q: how do you train when your number of classes differs from COCO's?

step0: download the pretrained weights

Link: https://pan.baidu.com/s/12eLhn6_LLY0RudwOVJLs6A (extraction code: 7agn)

step1: build your own dataset

The annotations were made with labelme and converted to COCO format; the method I described earlier for keras-maskrcnn applies here as well:
https://blog.csdn.net/qq_35608277/article/details/79873456
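Whatever conversion tool you use, the result should be a single COCO-style JSON file. A minimal sketch of the expected structure (field names follow the COCO annotation format; the file name, image, and category entries below are hypothetical):

```python
import json

# Minimal COCO-style annotation skeleton: "images", "annotations" and
# "categories" are the three top-level lists the COCO loader expects.
coco = {
    "images": [
        {"id": 1, "file_name": "127.jpg", "width": 800, "height": 600},
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 1,
            "bbox": [100, 100, 50, 80],  # [x, y, width, height]
            "area": 50 * 80,
            # polygon as a flat [x1, y1, x2, y2, ...] list
            "segmentation": [[100, 100, 150, 100, 150, 180, 100, 180]],
            "iscrowd": 0,
        },
    ],
    # one entry per foreground class; no background entry here
    "categories": [{"id": 1, "name": "class_1"}],
}

with open("instances_train.json", "w") as f:
    json.dump(coco, f)
```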

step2: place the data in the project

Create a datasets directory and its subfolders under the project root. My overall file layout follows the same structure.
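A layout consistent with the default coco_2017 entries in paths_catalog.py would look like this (folder names are the framework's defaults; adjust them to match your own data):

```
maskrcnn-benchmark/
└── datasets/
    └── coco/
        ├── annotations/
        │   ├── instances_train2017.json
        │   └── instances_val2017.json
        ├── train2017/      # training images
        └── val2017/        # validation images
```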

step3: edit the cfg file

Edit the cfg file and the parameters and paths in the config folder (defaults.py).
I have 5 classes, so the class count, dataset paths, and training parameters all need updating.
Also make sure paths_catalog.py matches the location of your data. Mine uses:

    DATA_DIR = "datasets"
    DATASETS = {
        "coco_2017_train": {
            "img_dir": "coco/train2017",
            "ann_file": "coco/annotations/instances_train2017.json"
        },
        "coco_2017_val": {
            "img_dir": "coco/val2017",
            "ann_file": "coco/annotations/instances_val2017.json"
        },
        ...
    }
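For a custom dataset you can either reuse the coco_2017 names above (simplest, since the default cfg already references them) or register entries of your own. A sketch of what such entries might look like (the "mydata" names and paths are placeholders, not part of the framework):

```python
# Hypothetical extra entries for the DATASETS dict in paths_catalog.py;
# directory and file names are placeholders for your own data layout.
DATASETS = {
    "mydata_train": {
        "img_dir": "mydata/train",
        "ann_file": "mydata/annotations/instances_train.json",
    },
    "mydata_val": {
        "img_dir": "mydata/val",
        "ann_file": "mydata/annotations/instances_val.json",
    },
}
```

If you add new names here, remember to reference them from the cfg's DATASETS.TRAIN / DATASETS.TEST tuples as well.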

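The class-count change itself lives in the cfg: with 5 foreground classes plus background, NUM_CLASSES becomes 6. A hedged fragment of the relevant keys (key names follow maskrcnn-benchmark's yacs config; the WEIGHT path is the mymodel.pth produced in step4 and is an example, not a fixed location):

```yaml
# e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml (relevant keys only)
MODEL:
  WEIGHT: "mask/weights/mymodel.pth"   # trimmed checkpoint from step4
  ROI_BOX_HEAD:
    NUM_CLASSES: 6                     # 5 classes + 1 background
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
```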

step4: adapt the pretrained weight file

The downloaded weights were trained on COCO's 81 classes, so loading them with a different class count raises shape-mismatch errors.
The fix is to delete the final output branches before loading.
The repository suggests the script trim_detectron_model.py, but it does not work for this .pth file.
Instead, use the script below: delete the weights named in the error messages and save the result as mymodel.pth, which then serves as the pretrained checkpoint.

import torch

# path points at the downloaded COCO-pretrained checkpoint
_d = torch.load(path)

def removekey(d, listofkeys):
    r = dict(d)  # copy so the original dict stays untouched
    for key in listofkeys:
        del r[key]
    return r

# drop the class-count-dependent output layers (box classifier,
# box regressor, and mask logits)
newdict = dict(_d)
newdict['model'] = removekey(_d['model'], [
    'module.roi_heads.box.predictor.cls_score.bias',
    'module.roi_heads.box.predictor.cls_score.weight',
    'module.roi_heads.box.predictor.bbox_pred.bias',
    'module.roi_heads.box.predictor.bbox_pred.weight',
    'module.roi_heads.mask.predictor.mask_fcn_logits.weight',
    'module.roi_heads.mask.predictor.mask_fcn_logits.bias',
])
torch.save(newdict, 'mymodel.pth')
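The deletion logic can be sanity-checked on a plain dict before touching the real checkpoint (the dummy keys and values below are illustrative only):

```python
def removekey(d, listofkeys):
    r = dict(d)  # copy so the original dict stays intact
    for key in listofkeys:
        del r[key]
    return r

# fake state dict standing in for checkpoint['model']
fake_state = {
    "backbone.conv1.weight": 0,
    "module.roi_heads.box.predictor.cls_score.weight": 1,
    "module.roi_heads.box.predictor.cls_score.bias": 2,
}
trimmed = removekey(fake_state, [
    "module.roi_heads.box.predictor.cls_score.weight",
    "module.roi_heads.box.predictor.cls_score.bias",
])
print(sorted(trimmed))  # ['backbone.conv1.weight']
```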

step5: start training

From the project root, run:

python tools/train_net.py --config-file configs/my_cfg/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml

The training log shows the remaining weights loading correctly:

2019-03-28 11:33:42,969 maskrcnn_benchmark.utils.checkpoint INFO: Loading checkpoint from /home/dong/MASK_RCNN/maskrcnn-benchmark-master/mask/weights/mymodel.pth
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn1.bias                   loaded from backbone.body.layer1.0.bn1.bias                   of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn1.running_mean           loaded from backbone.body.layer1.0.bn1.running_mean           of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn1.running_var            loaded from backbone.body.layer1.0.bn1.running_var            of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn1.weight                 loaded from backbone.body.layer1.0.bn1.weight                 of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn2.bias                   loaded from backbone.body.layer1.0.bn2.bias                   of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn2.running_mean           loaded from backbone.body.layer1.0.bn2.running_mean           of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn2.running_var            loaded from backbone.body.layer1.0.bn2.running_var            of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn2.weight                 loaded from backbone.body.layer1.0.bn2.weight                 of shape (64,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn3.bias                   loaded from backbone.body.layer1.0.bn3.bias                   of shape (256,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn3.running_mean           loaded from backbone.body.layer1.0.bn3.running_mean           of shape (256,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn3.running_var            loaded from backbone.body.layer1.0.bn3.running_var            of shape (256,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.bn3.weight                 loaded from backbone.body.layer1.0.bn3.weight                 of shape (256,)
2019-03-28 11:33:43,041 maskrcnn_benchmark.utils.model_serialization INFO: backbone.body.layer1.0.conv1.weight               loaded from backbone.body.layer1.0.conv1.weight               of shape (64, 64, 1, 1)
...

step6: test on your own dataset

First create a test cfg:
config_file = "../configs/my_cfg/test_mask_rcnn_R_50_FPN.yaml"
and change its weight path to the weights you just trained.
Then add your own data class:
copy the coco class in demo/predictor.py,
rename it Mydata,
and change only

CATEGORIES = [
    "__background",
    "",
    ...
]

so that the labels drawn at the end are your own class names.
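If you prefer not to copy the whole class, subclassing also works, since the class only consults its CATEGORIES attribute when labelling detections. A sketch using a stand-in for demo.predictor.COCODemo (the stand-in and the five class names are hypothetical; in the real project you would import COCODemo instead):

```python
# Stand-in for demo.predictor.COCODemo; in the real project, import it instead.
class COCODemo:
    CATEGORIES = ["__background", "person", "bicycle"]  # 81 entries in the real file

# Override only the label list; all inference code is inherited unchanged.
class Mydata(COCODemo):
    CATEGORIES = [
        "__background",
        "class_1", "class_2", "class_3", "class_4", "class_5",  # your 5 names
    ]
```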

Test script:

# -*- coding: utf-8 -*-
import time

import cv2

from maskrcnn_benchmark.config import cfg
from demo.predictor import Mydata

config_file = "../configs/my_cfg/test_mask_rcnn_R_50_FPN.yaml"
cfg.merge_from_file(config_file)

# object that runs inference and draws the predictions on the image
my_detection = Mydata(
    cfg,
    confidence_threshold=0.1,
    min_image_size=800,
)

res = cv2.imread('127.jpg')

start = time.perf_counter()  # time.clock() is deprecated
composite = my_detection.run_on_opencv_image(res)
end = time.perf_counter()
print(end - start)

cv2.imshow("detections", composite)
cv2.waitKey(0)
cv2.destroyWindow('detections')

step7: evaluate the results

https://github.com/Cartucho/mAP
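That tool matches predictions to ground truth by IoU before computing per-class AP; the IoU of two boxes in [x1, y1, x2, y2] format can be sketched as:

```python
def iou(a, b):
    """Intersection-over-union of two boxes in [x1, y1, x2, y2] format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

print(iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 25 / 175
```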

Code analysis

Reference:
https://www.cnblogs.com/ranjiewen/p/10001590.html
