How to use a custom dataset with maskrcnn-benchmark (PyTorch)

Preface

Reference code: maskrcnn-benchmark

Dataset source: Jinnan Digital Manufacturing Algorithm Challenge [Track 2], preliminary round

The code here cannot be run as-is; it is provided only as a reference. I have been working on detection for less than a week myself, so if anything is unclear, feel free to ask in the comments.

1. Understanding the data

The training annotation file train_no_poly.json is in a COCO-like format:

import json
with open('../train_no_poly.json', 'r') as f:
    data = json.load(f)

print(data.keys())
>>> dict_keys(['info', 'licenses', 'categories', 'images', 'annotations'])

print(data['info'])
>>> {'description': 'XRAY Instance Dataset ', 'url': '', 'version': '0.2.0', 'year': 2019, 'contributor': 'qianxiao', 'date_created': '2019-03-04 08:52:50.852455'}

print(data['licenses'])
>>> [{'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License', 'url': ''}]

print(data['categories'])
>>> [{'id': 1, 'name': '铁壳打火机', 'supercategory': 'restricted_obj'}, {'id': 2, 'name': '黑钉打火机', 'supercategory': 'restricted_obj'}, {'id': 3, 'name': '刀具', 'supercategory': 'restricted_obj'}, {'id': 4, 'name': '电源和电池', 'supercategory': 'restricted_obj'}, {'id': 5, 'name': '剪刀', 'supercategory': 'restricted_obj'}]

print(data['images'][0])
>>> {'coco_url': '', 'data_captured': '', 'file_name': '190119_184244_00166940.jpg', 'flickr_url': '', 'id': 0, 'height': 391, 'width': 680, 'license': 1}

print(data['annotations'][0])  # Note: one image can have several bboxes; each bbox is stored as a separate dict in the JSON
>>> {'id': 1, 'image_id': 0, 'category_id': 3, 'iscrowd': 0, 'segmentation': [], 'area': [], 'bbox': [88, 253, 118, 42], 'minAreaRect': [[88, 298], [86, 256], [203, 249], [206, 291]]}
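
To see at a glance how many boxes each image has, you can group the flat annotation list by image_id (a quick sketch using the fields shown above; this is essentially what the dataset class in step 5 does in its __init__):

from collections import defaultdict

# group the flat annotation list by image_id
anns_per_image = defaultdict(list)
for anno in data['annotations']:
    anns_per_image[anno['image_id']].append((anno['bbox'], anno['category_id']))

print(len(anns_per_image))   # number of images that have at least one bbox
print(anns_per_image[0])     # all (bbox, category_id) pairs for image_id 0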

2. Copy the dataset into datasets/ in the repo root (the same level as demo), e.g.

maskrcnn-benchmark/datasets/jinnan/jinnan2_round1_train_20190305

3. Modify paths_catalog.py

The file lives at maskrcnn-benchmark/maskrcnn_benchmark/config/paths_catalog.py

a. Add an entry for your dataset to the DATASETS dict in paths_catalog.py, e.g.

"jinnan_train": {
"img_dir": "jinnan2_round1_train_20190305",  # rgb格式文件路径
"ann_file": "jinnan2_round1_train_20190305/train_no_poly.json"
},

Note: for a custom dataset, img_dir and ann_file are what get passed as arguments into the MyDataset class you create yourself.

b. Modify the static method get(name) in paths_catalog.py

Add an elif branch that handles your dataset, e.g.

elif "jinnan" in name:  # name对应yaml文件传过来的数据集名字
    data_dir = DatasetCatalog.DATA_DIR
    attrs = DatasetCatalog.DATASETS[name]
    args = dict(
        root=os.path.join(data_dir, attrs["img_dir"]),  # img_dir就是a步骤里面的内容
        ann_file=os.path.join(data_dir, attrs["ann_file"]),  # ann_file就是a步骤里面的内容
    )
    return dict(
        factory="MyDataset",  # 这个MyDataset对应
        args=args,
    )
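
For reference, assuming DatasetCatalog.DATA_DIR is "datasets" (the default in paths_catalog.py), the entry above should now resolve like this:

from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog

# should print roughly:
# {'factory': 'MyDataset',
#  'args': {'root': 'datasets/jinnan2_round1_train_20190305',
#           'ann_file': 'datasets/jinnan2_round1_train_20190305/train_no_poly.json'}}
print(DatasetCatalog.get("jinnan_train"))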

Explanation of the parameters in the elif above (mainly MyDataset):

  1. MyDataset is the class you write yourself; it returns image, boxlist, idx. For the concrete implementation see the official GitHub repo (it is straightforward).

  2. Say the MyDataset class is implemented and the .py file is named jinnan.py

  3. Put it under maskrcnn-benchmark/maskrcnn_benchmark/data/datasets

  4. Then edit the __init__.py in that directory; the fourth import line and the last element of __all__ are the ones I added:

from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .jinnan import MyDataset

__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "MyDataset"]

  5. Note that MyDataset must implement __len__, __getitem__ and get_img_info, plus __init__; __init__ receives the attrs passed along from the step above. A reference for the __init__ signature:
def __init__(self, ann_file=None, root=None, remove_images_without_annotations=None, transforms=None)

If you are not sure what these parameters mean, look at maskrcnn-benchmark/maskrcnn_benchmark/data/build.py
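
Roughly, this is what build.py does with the catalog entry (a paraphrased sketch, not the exact source), which is why __init__ has to accept these keyword arguments:

# paraphrased sketch of how maskrcnn_benchmark/data/build.py instantiates a dataset
from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
from maskrcnn_benchmark.data import datasets as D


def build_one_dataset(dataset_name, transforms=None):
    data = DatasetCatalog.get(dataset_name)   # the dict returned by get() in step 3b
    factory = getattr(D, data["factory"])     # e.g. MyDataset, exported via __init__.py
    args = data["args"]                       # root=..., ann_file=...
    args["transforms"] = transforms           # build.py injects the transforms here
    return factory(**args)                    # MyDataset(root=..., ann_file=..., transforms=...)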

4. Modify the yaml config file

The main change is the data-loading part:

MODEL:
  MASK_ON: False
DATASETS:
  TRAIN: ("jinnan_train", "jinnan_val")
  TEST: ("jinnan_test",)

The three dataset names above are whatever you chose in paths_catalog.py; only jinnan_train is actually used here. Most importantly, MASK_ON must be turned off first.

5. My own (messy) data-loading code, for reference

Note that this is only a reference; just return image, boxlist, idx in whatever way fits your own needs.

File location: maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/jinnan.py

from maskrcnn_benchmark.structures.bounding_box import BoxList
from PIL import Image
import os
import json
import torch

class MyDataset(object):
    def __init__(self,ann_file=None, root=None, remove_images_without_annotations=None, transforms=None):
        # as you would do normally

        self.transforms = transforms

        self.train_path = root
        with open(ann_file, 'r') as f:
            self.data = json.load(f)

        self.idxs = list(range(len(self.data['images'])))  # store the image indices in a list so they are easy to index/shuffle
        self.bbox_label = {}  # image_id -> [list of boxes, list of category ids]
        for anno in self.data['annotations']:
            bbox = anno['bbox']
            bbox[2] += bbox[0]  # convert COCO-style [x, y, w, h] to [x1, y1, x2, y2]
            bbox[3] += bbox[1]
            cate = anno['category_id']
            image_id = anno['image_id']
            if image_id not in self.bbox_label:
                self.bbox_label[image_id] = [[bbox], [cate]]
            else:
                self.bbox_label[image_id][0].append(bbox)
                self.bbox_label[image_id][1].append(cate)

    def __getitem__(self, idx):
        # load the image as a PIL Image
        idx = self.idxs[idx % len(self.data['images'])]
        # if idx not in self.bbox_label:  # 210, 262, 690, 855 have no bbox
        #    idx += 1
        path = self.data['images'][idx]['file_name']

        folder = 'restricted' if idx < 981 else 'normal'  # in this dataset the first 981 image ids live in 'restricted', the rest in 'normal'

        image = Image.open(os.path.join(self.train_path, folder, path)).convert('RGB')
        # load the bounding boxes as a list of list of boxes
        # in this case, for illustrative purposes, we use
        # x1, y1, x2, y2 order.
        # boxes = [[0, 0, 10, 10], [10, 20, 50, 50]]
        boxes = self.bbox_label[idx][0]
        category = self.bbox_label[idx][-1]

        # and labels
        labels = torch.tensor(category)

        # create a BoxList from the boxes
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        # add the labels to the boxlist
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in your dataset
        return image, boxlist, idx

    def __len__(self):
        return len(self.data['images'])

    def get_img_info(self, idx):
        idx = self.idxs[idx % len(self.data['images'])]
        height = self.data['images'][idx]['height']
        width = self.data['images'][idx]['width']
        # get img_height and img_width. This is used if
        # we want to split the batches according to the aspect ratio
        # of the image, as it can be more efficient than loading the
        # image from disk
        return {"height": height, "width": width}

Other notes

Transforms: see maskrcnn-benchmark/maskrcnn_benchmark/data/build.py

In the yaml file, change WEIGHT to the path of your own weights, down to the file name.

Change _C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81 to 6 (5 categories plus background).

At one point I had set category_id to image_id, which gave:

copy_if failed to synchronize: device-side assert triggered

https://github.com/facebookresearch/maskrcnn-benchmark/issues/450
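
That device-side assert usually means a label fell outside the range the box head expects. A quick check against the annotation file (a sketch, reusing the JSON from step 1):

import json

with open('train_no_poly.json', 'r') as f:
    data = json.load(f)

valid_ids = {c['id'] for c in data['categories']}   # {1, 2, 3, 4, 5}
bad = [a for a in data['annotations'] if a['category_id'] not in valid_ids]
print(len(bad), 'annotations with an out-of-range category_id')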
