参考代码: mask rcnn benchmark
数据集来源:津南数字制造算法挑战赛【赛场二】初赛
这个代码不能直接运行,仅仅提供参考,本人也仅仅是接触检测不到一个礼拜,如果有什么疑问欢迎在讨论区交流。
训练集标注文件 train_no_poly.json 的格式(类 COCO 风格)如下:
import json
with open('../train_no_poly.json', 'r') as f:
data = json.load(f)
print(data.keys())
>>> dict_keys(['info', 'licenses', 'categories', 'images', 'annotations'])
print(data['info'])
>>> {'description': 'XRAY Instance Dataset ', 'url': '', 'version': '0.2.0', 'year': 2019, 'contributor': 'qianxiao', 'date_created': '2019-03-04 08:52:50.852455'}
print(data['licenses'])
>>> [{'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License', 'url': ''}]
print(data['categories'])
>>> [{'id': 1, 'name': '铁壳打火机', 'supercategory': 'restricted_obj'}, {'id': 2, 'name': '黑钉打火机', 'supercategory': 'restricted_obj'}, {'id': 3, 'name': '刀具', 'supercategory': 'restricted_obj'}, {'id': 4, 'name': '电源和电池', 'supercategory': 'restricted_obj'}, {'id': 5, 'name': '剪刀', 'supercategory': 'restricted_obj'}]
print(data['images'][0])
>>> {'coco_url': '', 'data_captured': '', 'file_name': '190119_184244_00166940.jpg', 'flickr_url': '', 'id': 0, 'height': 391, 'width': 680, 'license': 1}
print(data['annotations'][0]) # 注意,一个图像可能有多个bbox,json中把每个bbox分别存放在不同的字典中
>>> {'id': 1, 'image_id': 0, 'category_id': 3, 'iscrowd': 0, 'segmentation': [], 'area': [], 'bbox': [88, 253, 118, 42], 'minAreaRect': [[88, 298], [86, 256], [203, 249], [206, 291]]}
maskrcnn-benchmark/datasets/jinnan/jinnan2_round1_train_20190305
路径为maskrcnn-benchmark/maskrcnn_benchmark/config/paths_catalog.py
DATASETS
字典中添加你需要的路径,如"jinnan_train": {
"img_dir": "jinnan2_round1_train_20190305", # rgb格式文件路径
"ann_file": "jinnan2_round1_train_20190305/train_no_poly.json"
},
注意:自定义数据集的话,img_dir
和ann_file
会作为形参传到你自己创建的MyDataset
类里面
添加一个if else,把你创建的数据集相关内容放进去,如
elif "jinnan" in name: # name对应yaml文件传过来的数据集名字
data_dir = DatasetCatalog.DATA_DIR
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=os.path.join(data_dir, attrs["img_dir"]), # img_dir就是a步骤里面的内容
ann_file=os.path.join(data_dir, attrs["ann_file"]), # ann_file就是a步骤里面的内容
)
return dict(
factory="MyDataset", # 这个MyDataset对应
args=args,
)
上面参数解释(主要是MyDataset
):
这个MyDataset
就是你自己建的那个类,返回值是image, boxlist, idx
,具体实现参考git官网(很容易)
比如我实现好了MyDataset
类,然后这个py
文件取名为jinnan.py
然后放在maskrcnn-benchmark/maskrcnn_benchmark/data/datasets
路径下
接着配置该目录下的 __init__.py 文件:第四行 import 以及 __all__ 列表的最后一个元素是自己加的
# Dataset classes exported by the maskrcnn_benchmark.data.datasets package.
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .jinnan import MyDataset  # custom dataset implemented in jinnan.py

# Public API of the package. The name must be ``__all__``: a plain ``all``
# would only shadow the builtin ``all()`` and declare nothing.
# "MyDataset" is the custom addition.
__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "MyDataset"]
MyDataset 类需要实现 __len__、__getitem__、get_img_info,还有 __init__;其中 __init__ 会得到第一个步骤里由 attrs 构造出来的参数(root 和 ann_file)。__init__ 的参数可参考:def __init__(self,ann_file=None, root=None, remove_images_without_annotations=None, transforms=None)
不知参数是什么意思得去看maskrcnn-benchmark/maskrcnn_benchmark/data/build.py
主要是修改数据load部分
MODEL:
MASK_ON: False
DATASETS:
TRAIN: ("jinnan_train", "jinnan_val")
TEST: ("jinnan_test",)
上面三个值都是自己设的,其实有用的就jinnan_train
,当然首先重要的是要把MASK_ON
关闭。
注意只是参考,根据自己的不同需求返回image, boxlist, idx
就行
放置路径:maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/jinnan.py
from maskrcnn_benchmark.structures.bounding_box import BoxList
from PIL import Image
import os
import json
import torch
class MyDataset(object):
    """Detection dataset backed by a COCO-like annotation JSON file.

    The annotation file is expected to contain an ``images`` list (each
    entry with ``file_name``, ``height`` and ``width``) and an
    ``annotations`` list (each entry with ``image_id``, ``category_id`` and
    a ``bbox`` stored as ``[x, y, width, height]``).

    Args:
        ann_file: Path to the annotation JSON file.
        root: Directory that contains the ``restricted`` / ``normal`` image
            sub-folders.
        remove_images_without_annotations: When truthy, images that have no
            annotation are left out of the dataset. When falsy (default),
            they are kept and yield an empty ``BoxList``.
        transforms: Optional callable applied as
            ``transforms(image, boxlist)``.
    """

    # Images located before this position in ``data['images']`` are read
    # from the 'restricted' sub-folder, all later ones from 'normal'.
    # NOTE(review): 981 is specific to the original training split - confirm
    # before reusing this class with another annotation file.
    RESTRICTED_IMAGE_COUNT = 981

    def __init__(self, ann_file=None, root=None,
                 remove_images_without_annotations=None, transforms=None):
        self.transforms = transforms
        self.train_path = root
        # JSON is UTF-8; be explicit so non-ASCII category names load the
        # same way regardless of the platform's default encoding.
        with open(ann_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # image_id -> [[xyxy boxes], [category ids]]
        self.bbox_label = {}
        for anno in self.data['annotations']:
            # Work on a copy so the loaded annotation keeps its xywh box.
            bbox = list(anno['bbox'])
            bbox[2] += bbox[0]  # width  -> x2
            bbox[3] += bbox[1]  # height -> y2
            boxes, cates = self.bbox_label.setdefault(
                anno['image_id'], [[], []])
            boxes.append(bbox)
            cates.append(anno['category_id'])

        # Positions (into data['images']) of the images served by the
        # dataset; kept as a list so it can be shuffled or filtered.
        positions = range(len(self.data['images']))
        if remove_images_without_annotations:
            positions = [p for p in positions
                         if self._image_id(p) in self.bbox_label]
        self.idxs = list(positions)

    def _image_id(self, position):
        """Return the annotation key of the image stored at ``position``."""
        # Fall back to the list position when the entry carries no 'id'.
        return self.data['images'][position].get('id', position)

    def __getitem__(self, idx):
        """Return ``(image, boxlist, idx)`` for dataset index ``idx``.

        ``idx`` wraps around the dataset length. Images without annotations
        produce a ``BoxList`` holding zero boxes instead of raising
        ``KeyError``.
        """
        idx = idx % len(self.idxs)
        position = self.idxs[idx]
        info = self.data['images'][position]

        if position < self.RESTRICTED_IMAGE_COUNT:
            folder = 'restricted'
        else:
            folder = 'normal'
        # load the image as a PIL Image
        image = Image.open(
            os.path.join(self.train_path, folder, info['file_name'])
        ).convert('RGB')

        boxes, category = self.bbox_label.get(
            self._image_id(position), ([], []))
        # reshape keeps the tensor two-dimensional (N x 4) even when N == 0,
        # which BoxList requires.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(category, dtype=torch.int64)

        # boxes are stored in x1, y1, x2, y2 order
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in the dataset
        return image, boxlist, idx

    def __len__(self):
        """Return the number of images served by the dataset."""
        return len(self.idxs)

    def get_img_info(self, idx):
        """Return ``{"height": ..., "width": ...}`` for dataset index ``idx``.

        The values come from the annotation file, so the image itself does
        not have to be read from disk.
        """
        info = self.data['images'][self.idxs[idx % len(self.idxs)]]
        return {"height": info['height'], "width": info['width']}
transform maskrcnn-benchmark/maskrcnn_benchmark/data/build.py
在yaml里面把weight改成自己权重的路径,精确到文件
_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81改成6
踩坑记录:我这里曾把 category_id 误设成了 image_id,标签值超出了类别数(NUM_CLASSES),训练时报如下错误:
copy_if failed to synchronize: device-side assert triggered
https://github.com/facebookresearch/maskrcnn-benchmark/issues/450