Deep Learning 7: Transformer Series, Instance Segmentation with Mask2Former

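Table of Contents

  • Preface
  • Main Text
    • Open-source repository
    • Installation
    • Verification (download the corresponding model)
    • Training
      • Registering a custom dataset
      • Specifying the training dataset
      • Training
      • Troubleshooting
        • 1) Out of GPU memory
        • 2) TorchScript error: Global alloc not supported yet
        • 3) Dataset class count does not match the model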

Preface

Main Text

Open-source repository

https://github.com/facebookresearch/Mask2Former

Installation

See https://github.com/facebookresearch/Mask2Former/blob/main/INSTALL.md

conda create --name mask2former python=3.8 -y
conda activate mask2former
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# Install cudatoolkit from the anaconda channel instead of the default one
conda install cudatoolkit -c anaconda
pip install opencv-python

# Install the detectron2 API
git clone git@github.com:facebookresearch/detectron2.git
cd detectron2
pip install -e .
pip install git+https://github.com/cocodataset/panopticapi.git
pip install git+https://github.com/mcordts/cityscapesScripts.git

cd ..
git clone git@github.com:facebookresearch/Mask2Former.git
cd Mask2Former
pip install -r requirements.txt
cd mask2former/modeling/pixel_decoder/ops
sh make.sh
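
Before moving on, it can help to confirm that the packages installed above are importable from the active conda environment. The following is a minimal sketch that uses only the standard library. The helper name `check_packages` is hypothetical, and the import names (`cv2` for opencv-python, `cityscapesscripts` for cityscapesScripts) are assumptions about how those packages are usually imported.

```python
import importlib.util


def check_packages(names):
    """Return {top-level module name: True if the module can be located}."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names for the packages installed in this section.
    required = ["torch", "torchvision", "cv2", "detectron2",
                "panopticapi", "cityscapesscripts"]
    for name, found in check_packages(required).items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```

A package reported as MISSING usually means the install step ran in a different environment than the one currently activated.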

Verification (download the corresponding model)

conda activate mask2former
cd Mask2Former/demo
python demo.py --config-file ../configs/coco/panoptic-segmentation/maskformer2_R50_bs16_50ep.yaml   --input 1.jpg --output ./output
python demo.py --config-file ../configs/coco/instance-segmentation/swin/maskformer2_swin_tiny_bs16_50ep.yaml --input 2.jpg --output ./tiny --opts MODEL.WEIGHTS "../weights/swin_tiny_patch4_window7_224.pkl"
python demo.py --config-file ../configs/coco/instance-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml --input 2.jpg --output ./large --opts MODEL.WEIGHTS "../weights/swin_large_patch4_window12_384_22k.pkl"   

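Training

Both training and inference in Mask2Former are built on the detectron2 API. Before training, build your own dataset and register it with detectron2.

Registering a custom dataset

Detailed documentation:
https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html

Registration example: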
https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5#scrollTo=PIbAM2pv-urF

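import json
import os

import cv2
import numpy as np
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode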

def get_balloon_dicts(img_dir):
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}
        
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]
        
        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width
      
        annos = v["regions"]
        objs = []
        for _, anno in annos.items():
            assert not anno["region_attributes"]
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

for d in ["train", "val"]:
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
balloon_metadata = MetadataCatalog.get("balloon_train")
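
To make the inner loop of the registration function easier to follow, here is a small standalone sketch of how one polygon becomes the `bbox` and `segmentation` fields. The helper name `polygon_to_annotation` is hypothetical, and `bbox_mode` is left out so the snippet runs without detectron2 installed.

```python
def polygon_to_annotation(px, py, category_id=0):
    """Build the box and polygon fields of one annotation dict."""
    # Apply the same 0.5 offset as the registration function, then flatten
    # [(x0, y0), (x1, y1), ...] into [x0, y0, x1, y1, ...].
    points = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
    flat = [coord for point in points for coord in point]
    return {
        "bbox": [min(px), min(py), max(px), max(py)],  # XYXY, absolute pixels
        "segmentation": [flat],  # a list of polygons, here just one
        "category_id": category_id,
    }


anno = polygon_to_annotation([10, 20, 20], [5, 5, 15])
print(anno["bbox"])          # [10, 5, 20, 15]
print(anno["segmentation"])  # [[10.5, 5.5, 20.5, 5.5, 20.5, 15.5]]
```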

For datasets already in COCO format, register them directly through the API:

from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train", {}, "json_annotation_train.json", "path/to/image/dir")
register_coco_instances("my_dataset_val", {}, "json_annotation_val.json", "path/to/image/dir")
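
Registration itself does not read the annotation file, so problems in the JSON tend to surface only once training starts. The sketch below is one possible sanity check to run beforehand, assuming the file follows the usual COCO instances layout with `images`, `annotations`, and `categories` lists. The function name `summarize_coco` is hypothetical.

```python
import json


def summarize_coco(json_path):
    """Return (category names, image count, annotation count) for a COCO file."""
    with open(json_path) as f:
        coco = json.load(f)
    image_ids = {img["id"] for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}
    for ann in coco["annotations"]:
        # Every annotation must point at a known image and a known category.
        if ann["image_id"] not in image_ids:
            raise ValueError(f"annotation {ann['id']} references an unknown image")
        if ann["category_id"] not in category_ids:
            raise ValueError(f"annotation {ann['id']} references an unknown category")
    names = [cat["name"] for cat in coco["categories"]]
    return names, len(coco["images"]), len(coco["annotations"])
```

Calling `summarize_coco("json_annotation_train.json")` returns the category names, and the length of that list is the class count the model config needs to agree with.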

Specifying the training dataset

_BASE_: ../maskformer2_R50_bs16_50ep.yaml
DATASETS:
  TRAIN: ("my_dataset_train",)
  TEST: ("my_dataset_val",)
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  MASK_FORMER:
    NUM_OBJECT_QUERIES: 200
SOLVER:
  STEPS: (655556, 710184)
  MAX_ITER: 737500
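
The SOLVER values above are counted in iterations, not epochs, so they are tied to the size of the original dataset. Read together with the `bs16_100ep` suffix of the config name, MAX_ITER: 737500 is consistent with a batch size of 16 and 100 epochs over roughly 118,000 images. That reading is an inference, not something the config states. Under that assumption, one way to rescale the schedule for a custom dataset while keeping the decay steps at the same fractions of training is sketched below. The helper name `scale_schedule` is hypothetical.

```python
# Reference values copied from the config above.
REF_MAX_ITER = 737500
REF_STEPS = (655556, 710184)


def scale_schedule(num_images, epochs, ims_per_batch):
    """Return (max_iter, steps), with steps at the reference fractions of max_iter."""
    max_iter = num_images * epochs // ims_per_batch
    # Integer arithmetic keeps the reference case exact.
    steps = tuple(max_iter * step // REF_MAX_ITER for step in REF_STEPS)
    return max_iter, steps


print(scale_schedule(118000, 100, 16))  # (737500, (655556, 710184))
print(scale_schedule(1000, 50, 2))      # (25000, (22222, 24074))
```

The resulting numbers would go into SOLVER.MAX_ITER and SOLVER.STEPS. Whether the same decay fractions suit a small dataset is a judgment call and not something the original configuration guarantees.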

Training

cd Mask2Former
python train_net.py --num-gpus 1 --config-file configs/coco/instance-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml  MODEL.WEIGHTS "weights/swin_large_patch4_window12_384_22k.pkl"

Troubleshooting

1) Out of GPU memory

RuntimeError: CUDA out of memory. Tried to allocate 410.00 MiB (GPU 0; 10.91 GiB total capacity; 4.24 GiB already allocated; 151.44 MiB free; 4.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
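[Solution] Use a smaller model and a smaller batch size, both set in the config file. The config files inherit from one another through several layers, so check which parameters are set at each level.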

SOLVER:
  IMS_PER_BATCH: 1
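
Because each config can name a parent through `_BASE_`, a value such as IMS_PER_BATCH may be set in any file along the chain. The sketch below lists that chain for a given config so each file can be inspected in turn. It assumes `_BASE_` appears at the top level of the file and is a local path relative to the file that declares it. The function name `base_chain` is hypothetical.

```python
import os
import re

# Matches a top-level "_BASE_: path" line, with or without quotes.
BASE_LINE = re.compile(r"^_BASE_:\s*[\"']?([^\"'\s#]+)[\"']?")


def base_chain(config_path):
    """Return config_path followed by every file reached through _BASE_."""
    chain = []
    path = os.path.abspath(config_path)
    while path and path not in chain:  # stop at the root config or on a cycle
        chain.append(path)
        parent = None
        with open(path) as f:
            for line in f:
                match = BASE_LINE.match(line)
                if match:
                    # Resolve relative to the directory of the declaring file.
                    parent = os.path.normpath(
                        os.path.join(os.path.dirname(path), match.group(1)))
                    break
        path = parent
    return chain
```

Running `base_chain` on the Swin config used for training prints the files from most specific to most general. A key set in an earlier file overrides the same key in a later one.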

2) TorchScript error: Global alloc not supported yet

File "/dataset/projects/Mask2Former/mask2former/modeling/matcher.py", line 141, in memory_efficient_forward
cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Global alloc not supported yet
[Solution] See https://github.com/facebookresearch/Mask2Former/issues/4
In matcher.py, replace batch_dice_loss_jit with batch_dice_loss:

# cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) 
cost_dice = batch_dice_loss(out_mask, tgt_mask) 

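3) Dataset class count does not match the model

Set NUM_CLASSES in the config file to the number of classes in your dataset: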

_BASE_: ../maskformer2_R50_bs16_50ep.yaml

MODEL:
  RETINANET:
    NUM_CLASSES: 2
  ROI_HEADS:
    NUM_CLASSES: 2
  SEM_SEG_HEAD:
    NUM_CLASSES: 2
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  WEIGHTS: "swin_small_patch4_window7_224.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]

DATASETS:
  TRAIN: ("my_dataset_train",)
  TEST: ("my_dataset_val",)
  
SOLVER:
  IMS_PER_BATCH: 1
  
DATALOADER:
  NUM_WORKERS: 1

OUTPUT_DIR: ./output/small_wf_alarm
