不了解coco数据集形式的可以先去查找了解一下
下面的detectron2就是你conda环境中安装的detectron2,照着下面找一下:
/home/anaconda3/envs/你的conda环境名/lib/python3.7/site-packages/detectron2
从detectron2/data/datasets中修改builtin.py
# Content added to builtin.py
# Register your own dataset paths inside _PREDEFINED_SPLITS_COCO["coco"]
_PREDEFINED_SPLITS_COCO["coco"] = {
# Custom dataset:
# "coco_handwritten_train" is a name you choose (it is referenced again later);
# of the two paths, the first is the image directory, the second the matching COCO json file
# training split
"coco_handwritten_train": ("/mnt/big_disk/big_disk_1/zhuzhibo/download/AdelaiDet/datasets/handwritten_chinese_stroke_2021/train2021",
"/mnt/big_disk/big_disk_1/zhuzhibo/download/AdelaiDet/datasets/handwritten_chinese_stroke_2021/annotations/instances_train2021.json"),
# validation split
"coco_handwritten_val": ("/mnt/big_disk/big_disk_1/zhuzhibo/download/AdelaiDet/datasets/handwritten_chinese_stroke_2021/val2021",
"/mnt/big_disk/big_disk_1/zhuzhibo/download/AdelaiDet/datasets/handwritten_chinese_stroke_2021/annotations/instances_val2021.json"),
}
...
# Change _root under `if __name__.endswith(".builtin"):`
if __name__.endswith(".builtin"):
    # Assume pre-defined datasets live in `./datasets`.
    # Replace the fallback below with your own AdelaiDet root directory; the
    # DETECTRON2_DATASETS environment variable still takes precedence when set.
    # (A previous commented-out variant passed the path as the env-var NAME and
    # "datasets" as the default — that argument order is wrong for os.getenv.)
    _root = os.getenv("DETECTRON2_DATASETS", "/mnt/big_disk/big_disk_1/zhuzhibo/download/AdelaiDet")
从detectron2/data/datasets中修改builtin_meta.py
这里注意类别的定义顺序要与你转coco数据集时的一样(我没试不一样会怎么样),颜色填不重复的就行,isthing为1,id从1开始,name为这个类别的名称
# 将原来的COCO_CATEGORIES注释掉,换为新的
# Replacement for the stock COCO_CATEGORIES.
# The order must match the category order used when converting to COCO format;
# ids start at 1, isthing is 1 for every instance class, and every entry must
# carry a distinct visualization color.
COCO_CATEGORIES = [
    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "wangou"},
    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "na"},
    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "ti"},
    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "pie"},
    # was [0, 0, 230], duplicating id 4's color; colors must be unique
    {"color": [250, 170, 30], "isthing": 1, "id": 5, "name": "piezhe"},
    {"color": [106, 0, 228], "isthing": 1, "id": 6, "name": "piedian"},
    {"color": [0, 60, 100], "isthing": 1, "id": 7, "name": "xiegouhuowogou"},
    {"color": [0, 80, 100], "isthing": 1, "id": 8, "name": "heng"},
    {"color": [0, 0, 70], "isthing": 1, "id": 9, "name": "hengzhe"},
    {"color": [0, 0, 192], "isthing": 1, "id": 10, "name": "hengzhezhehuohengzhewan"},
    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "hengzhezhezhe"},
    {"color": [220, 220, 0], "isthing": 1, "id": 12, "name": "hengzhezhezhegouhuohengpiewangou"},
    {"color": [175, 116, 175], "isthing": 1, "id": 13, "name": "hengzhezhepie"},
    {"color": [250, 0, 30], "isthing": 1, "id": 14, "name": "hengzheti"},
    {"color": [165, 42, 42], "isthing": 1, "id": 15, "name": "hengzhegou"},
    {"color": [255, 77, 255], "isthing": 1, "id": 16, "name": "hengpiehuohenggou"},
    {"color": [0, 226, 252], "isthing": 1, "id": 17, "name": "hengxiegou"},
    {"color": [182, 182, 255], "isthing": 1, "id": 18, "name": "dian"},
    {"color": [0, 82, 0], "isthing": 1, "id": 19, "name": "shu"},
    {"color": [120, 166, 157], "isthing": 1, "id": 20, "name": "shuwan"},
    {"color": [110, 76, 0], "isthing": 1, "id": 21, "name": "shuwangou"},
    {"color": [174, 57, 255], "isthing": 1, "id": 22, "name": "shuzhezhegou"},
    {"color": [199, 100, 0], "isthing": 1, "id": 23, "name": "shuzhepiehuoshuzhezhe"},
    {"color": [72, 0, 118], "isthing": 1, "id": 24, "name": "shuti"},
    {"color": [169, 164, 131], "isthing": 1, "id": 25, "name": "shugou"},
]
# Update _get_coco_instances_meta() as follows.
# The assert checks for 25 classes here (background excluded) — change the
# number to your own class count.
def _get_coco_instances_meta():
    """Build instance metadata (id mapping, class names, colors) from COCO_CATEGORIES."""
    thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    assert len(thing_ids) == 25, len(thing_ids)
    # Mapping from the incontiguous COCO category id to an id in [0, 24]
    thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
    thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    ret = {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes,
        "thing_colors": thing_colors,
    }
    return ret
# 更改_get_coco_panoptic_separated_meta()中以下内容
assert len(stuff_ids) == 0, len(stuff_ids)
#assert len(stuff_ids) == 53, len(stuff_ids) 原来的语句
def get_parser():
    """Build the CLI parser for the dataset-preparation script.

    Returns:
        argparse.ArgumentParser exposing a single ``--dataset-name`` option
        that defaults to the handwritten stroke dataset.
    """
    # Description fixed: the previous text ("Keep only model in ckpt") was a
    # copy-paste leftover from an unrelated checkpoint-stripping script.
    parser = argparse.ArgumentParser(description="Generate dataset annotations")
    parser.add_argument(
        "--dataset-name",
        default="handwritten_chinese_stroke_2021",
        help="dataset to generate",
    )
    return parser
if __name__ == "__main__":
    args = get_parser().parse_args()
    # Dataset root sits next to this script, named after the dataset.
    dataset_dir = os.path.join(os.path.dirname(__file__), args.dataset_name)
    if args.dataset_name == "handwritten_chinese_stroke_2021":
        # Map the (possibly non-contiguous) COCO category ids to [0, N).
        thing_id_to_contiguous_id = _get_coco_instances_meta()["thing_dataset_id_to_contiguous_id"]
        split_name = 'train2021'
        annotation_name = "annotations/instances_{}.json"
...
MODEL:
  META_ARCHITECTURE: "BlendMask"
  MASK_ON: True
  BACKBONE:
    NAME: "build_fcos_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res3", "res4", "res5"]
  PROPOSAL_GENERATOR:
    NAME: "FCOS"
  BASIS_MODULE:
    LOSS_ON: True
  PANOPTIC_FPN:
    COMBINE:
      ENABLED: False
  FCOS:
    THRESH_WITH_CTR: True
    USE_SCALE: False
DATALOADER:
  # "DATALOADER: 8" is not a valid config node; the worker count belongs
  # under DATALOADER.NUM_WORKERS in detectron2's config schema.
  NUM_WORKERS: 8
DATASETS:  # use the names you registered in builtin.py in step one
  TRAIN: ("coco_handwritten_train",)  # training split
  TEST: ("coco_handwritten_val",)  # validation split
SOLVER:
  IMS_PER_BATCH: 32
  BASE_LR: 0.001  # Note that RetinaNet uses a different default learning rate
  STEPS: (200, 400)
  MAX_ITER: 600
INPUT:
  MIN_SIZE_TRAIN: (120, 120)  # minimum training image size; adjust as needed
  MIN_SIZE_TEST: 120  # minimum test image size; adjust as needed
_BASE_: "Base-BlendMask.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"
  BACKBONE:
    NAME: "build_fcos_resnet_fpn_backbone"
  RESNETS:
    DEPTH: 101
  FCOS:
    NUM_CLASSES: 25  # change to the number of classes in your custom dataset
# Adjust the INPUT section to your own needs
INPUT:
  FORMAT: BGR
  MASK_FORMAT: polygon
  MAX_SIZE_TEST: 144
  MAX_SIZE_TRAIN: 144
  MIN_SIZE_TEST: 112
  MIN_SIZE_TRAIN: [112, 120]
SOLVER:
  STEPS: [2500, 29999]
  MAX_ITER: 30000  # default 270000
  WARMUP_ITERS: 100  # default 1000
  CHECKPOINT_PERIOD: 2500  # default 5000
  IMS_PER_BATCH: 64  # default 16
  BASE_LR: 0.025  # default 0.02
TEST:
  EVAL_PERIOD: 2500  # default 0, means no evaluation will be performed
# NOTE: VIS_PERIOD is a top-level detectron2 key, not a TEST sub-key
VIS_PERIOD: 2500  # default 0, means no visualization will be performed
DATASETS:  # overrides the values inherited from Base-BlendMask.yaml
  TRAIN: ("coco_handwritten_train",)
  TEST: ("coco_handwritten_val",)
OUTPUT_DIR: "training_dir/blendmask/hw_R_101_3x_fpn"  # training output directory
# NOTE(review): fragment from the dataset mapper — method-level indentation was
# lost in extraction; restore it when pasting this back into the source file.
if self.basis_loss_on and self.is_train:
# load basis supervisions
"""
if self.ann_set == "coco":
basis_sem_path = (
dataset_dict["file_name"]
.replace("train2017", "thing_train2017")
.replace("image/train", "thing_train")
)
"""
# Replace the content above with the following (split renamed to train2021)
if self.ann_set == "coco":
basis_sem_path = (
dataset_dict["file_name"]
.replace("train2021", "thing_train2021")
)
def setup(args):
...
# Change the yaml config file in use to your own
args.config_file = '/home/zhuzhibo/AdelaiDet/configs/BlendMask/R_101_3x.yaml'
...