Fine-tuning a Mask2Former Model with MMSegmentation

Preface

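  • This post introduces mmsegmentation, a Python library dedicated to semantic segmentation models (GitHub project link). Everything runs in a Kaggle notebook on a P100 GPU.
  • It covers environment setup, inference with pretrained models, and fine-tuning the recent SOTA model Mask2Former on a watermelon dataset (dataset description).
  • Because the watermelon dataset is small, we also fine-tune Mask2Former on a histopathology-slide glomeruli dataset at the end (dataset description).
  • Parts of this tutorial follow the GitHub project MMSegmentation_Tutorials (project link).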

Environment setup

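  • The code needs the openmim, mmsegmentation, mmengine, mmdetection and mmcv packages. mmcv is awkward to set up on Kaggle and needs a prebuilt package. I have bundled all the prebuilt packages into the Kaggle dataset frozen-packages-mmdetection (see its details page).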
import IPython.display as display
!pip install -U openmim

!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .

!pip install "mmdet>=3.0.0rc4"

!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl

!pip install wandb
display.clear_output()
  • Tested on Kaggle on July 13, 2023: the code above installs everything the project needs and runs without errors.
  • Import commonly used packages
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Config

import matplotlib.pyplot as plt
%matplotlib inline

from mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()

from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

display.clear_output()
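The setup above pins specific package versions and was only checked on July 13, 2023. It can therefore be worth printing what actually ended up installed. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper `installed_versions` is hypothetical. The pip distribution names passed to it are my assumption of how these packages are named on PyPI.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == '__main__':
    # Assumed PyPI distribution names of the packages installed above
    for pkg, ver in installed_versions(
            ['mmsegmentation', 'mmengine', 'mmcv', 'mmdet', 'torch']).items():
        print(f'{pkg:15s} {ver or "not installed"}')
```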
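  • Create folders for the dataset, the pretrained weights and the inference outputs
# checkpoint folder: pretrained model weights (exist_ok avoids an error when the cell is re-run)
os.makedirs('checkpoint', exist_ok=True)

# outputs folder: prediction results
os.makedirs('outputs', exist_ok=True)

# data folder: test images and videos
os.makedirs('data', exist_ok=True)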
  • Download PSPNet, SegFormer and Mask2Former weights pretrained on Cityscapes, and save them in the checkpoint folder
# Download the pretrained models from the Model Zoo into the checkpoint folder
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
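OpenMMLab checkpoint names end in an 8-character hex tag, for example `-2966598c.pth`. As far as I know, this tag is the first 8 hex digits of the file's SHA-256, which is how OpenMMLab's `publish_model` script names releases. Assuming that holds, the sketch below checks whether a download is intact. `sha256_prefix_matches` is a hypothetical helper. A `False` result means either a corrupted download or that this particular file was named differently.

```python
import hashlib
import re
from pathlib import Path


def sha256_prefix_matches(path):
    """Compare the 8-hex-digit tag in a name like '...-2966598c.pth' with the
    file's SHA-256 digest. Returns None if the name carries no such tag."""
    path = Path(path)
    match = re.search(r'-([0-9a-f]{8})\.pth$', path.name)
    if match is None:
        return None
    digest = hashlib.sha256()
    with path.open('rb') as f:
        # Hash in 1 MiB chunks so large checkpoints are not loaded into memory at once
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    return digest.hexdigest().startswith(match.group(1))


if __name__ == '__main__':
    for p in sorted(Path('checkpoint').glob('*.pth')):
        print(p.name, sha256_prefix_matches(p))
```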
  • Download a few test images and videos into the data folder.
# London street-scene image
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data

# Shanghai driving street-scene video, source: https://www.youtube.com/watch?v=ll8TgCZ0plk
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data

# Street video, March 30, 2022
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()

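Image inference

Command-line inference

  • Run inference on an image from the command line and view the result with PIL
  • Both the PSPNet and the SegFormer models are used
# PSPNet model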
!python demo/image_demo.py \
        data/street_uk.jpeg \
        configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
        checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
        --out-file outputs/B1_uk_pspnet.jpg \
        --device cuda:0 \
        --opacity 0.5

display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')

# SegFormer model
!python demo/image_demo.py \
        data/street_uk.jpeg \
        configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
        checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
        --out-file outputs/B1_uk_segformer.jpg \
        --device cuda:0 \
        --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')

  • SegFormer does noticeably better than PSPNet here. It separates the different objects fairly well.

API inference

  • Run image inference through the mmsegmentation Python API
  • Use the Mask2Former model and visualize the result with matplotlib
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
# Model config file
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# Model checkpoint (weights) file
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# Use the GPU if present; on CPU, SyncBN layers must be converted back to regular BN
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = init_model(config_file, checkpoint_file, device=device)

if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()

display.clear_output()
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:,:,::-1])
plt.imshow(pred_mask, alpha=0.55) # alpha: overlay opacity; smaller values look closer to the original image
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()

[Figure 1: Mask2Former prediction overlaid on street_uk.jpeg]

  • As a SOTA model, Mask2Former produces very good results here.

Video inference

Command-line inference

  • Not recommended: it is very slow
!python demo/video_demo.py \
        data/street_20220330_174028.mp4 \
        configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
        checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
        --device cuda:0 \
        --output-file outputs/B3_video.mp4 \
        --opacity 0.5

API inference

  • Run Mask2Former on a video through the API
# Model config file
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# Model checkpoint (weights) file
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# Use the GPU if present; on CPU, SyncBN layers must be converted back to regular BN
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = init_model(config_file, checkpoint_file, device=device)

if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

display.clear_output()

input_video = 'data/street_20220330_174028.mp4'

temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('Created temporary folder {} for the per-frame predictions'.format(temp_out_dir))

# Class names and palette of the Cityscapes street-scene dataset
classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']

def predict_single_frame(img, opacity=0.2):
    # img is a BGR frame from mmcv.VideoReader;
    # opacity is the weight of the original frame (smaller -> mask colors more visible)
    result = inference_model(model, img)

    # Color the label map with the palette
    seg_map = result.pred_sem_seg.data[0].detach().cpu().numpy().astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))

    # The palette is RGB while the frame and cv2.imwrite use BGR, so flip the channels before blending
    seg_bgr = np.array(seg_img.convert('RGB'))[:, :, ::-1]
    show_img = seg_bgr * (1 - opacity) + img * opacity

    return show_img.astype(np.uint8)

# Read the input video
imgs = mmcv.VideoReader(input_video)

prog_bar = mmengine.ProgressBar(len(imgs))

# Process the video frame by frame
for frame_id, img in enumerate(imgs):

    ## Process a single frame
    show_img = predict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg' # save the blended prediction to the temporary folder
    cv2.imwrite(temp_path, show_img)

    prog_bar.update() # update the progress bar

# Join the frames into a video file
mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir) # delete the temporary per-frame folder
print('Deleted temporary folder', temp_out_dir)
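The per-frame function above colors the label map with a palette and alpha-blends it with the frame. There is one subtlety. `mmcv.VideoReader` yields BGR frames and `cv2.imwrite` expects BGR, but the Cityscapes palette is RGB. The palette should therefore be flipped before blending, or the colors come out channel-swapped. Below is a framework-free sketch of the same idea in NumPy. `colorize_and_blend` is a hypothetical helper, and `frame_weight` plays the role of `opacity` above (the weight of the original frame).

```python
import numpy as np


def colorize_and_blend(frame_bgr, seg_map, palette_rgb, frame_weight=0.15):
    """Color an (H, W) label map with an RGB palette and blend it onto an (H, W, 3) BGR frame.

    frame_weight is the weight of the original frame: 0 shows only the mask colors,
    1 shows only the frame.
    """
    # Palette as a (num_classes, 3) lookup table, channels reversed RGB -> BGR
    lut = np.asarray(palette_rgb, dtype=np.uint8)[:, ::-1]
    # Fancy indexing maps every class id to its color: (H, W) -> (H, W, 3)
    color_bgr = lut[np.asarray(seg_map, dtype=np.intp)]
    blended = color_bgr * (1.0 - frame_weight) + frame_bgr * frame_weight
    return np.clip(blended, 0, 255).astype(np.uint8)
```

Inside the frame loop this would be called as `colorize_and_blend(img, seg_map, palette, frame_weight=0.15)`. The lookup-table indexing avoids the round trip through a PIL palette image.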

Fine-tuning Mask2Former on a small dataset

  • Fine-tune the model on the watermelon semantic segmentation dataset

Downloading the dataset

!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip

!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null # unzip

!rm -rf Watermelon87_Semantic_Seg_Mask.zip # delete the zip archive

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data

# List leftover files generated automatically by the OS / Jupyter
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'

# Delete the leftover files
!for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done

# Verify that the leftover files are gone
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'

display.clear_output()
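The shell loops above split `find`'s output on whitespace, so a path containing a space would be broken into pieces. Below is a minimal pure-Python alternative using `pathlib`. Like `find -iname`, it matches the same three names case-insensitively. `remove_junk` is a hypothetical helper, not part of any library.

```python
import shutil
from pathlib import Path

# Lower-cased names to remove, matched case-insensitively like `find -iname`
JUNK_NAMES = {'__macosx', '.ds_store', '.ipynb_checkpoints'}


def remove_junk(root='.'):
    """Delete macOS/Jupyter leftovers under root. Returns the removed paths."""
    # Collect everything before deleting so removal does not disturb the walk;
    # deepest paths first so junk nested inside junk is handled before its parent.
    candidates = sorted(
        (p for p in Path(root).rglob('*') if p.name.lower() in JUNK_NAMES),
        key=lambda p: len(p.parts), reverse=True)
    removed = []
    for p in candidates:
        if not p.exists():
            continue
        if p.is_dir():
            shutil.rmtree(p)
        else:
            p.unlink()
        removed.append(p)
    return removed
```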

Visual exploration of the semantic segmentation dataset

  • Visualize the semantic labels
# Paths of a single image and its mask
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'

img = cv2.imread(img_path)
mask = cv2.imread(mask_path)

# Overlay the mask on the original image
plt.figure(figsize=(8, 8))
plt.imshow(img[:,:,::-1])
plt.imshow(mask[:,:,0], alpha=0.6) # alpha: overlay opacity; smaller values look closer to the original image
plt.axis('off')
plt.show()

[Figure 2: watermelon image 04_35-2.jpg with its annotation overlaid]

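Defining the Dataset and Pipeline

  • Dataset: defines which class each label value corresponds to, the display color of each class, the file format, and whether to ignore class 0
  • Pipeline: defines the data processing steps for training and validation, including the crop size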
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # Class names and their RGB colors
    METAINFO = {
        'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
        'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
    }
    
    # File suffixes for images and annotations
    def __init__(self,
                 seg_map_suffix='.png',   # file suffix of the mask images
                 reduce_zero_label=False, # if True, label 0 is ignored and the other labels shift down by 1
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
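`MyCustomDataset` relies on `classes` and `palette` lining up index by index with the label values stored in the masks (here 0 to 5). The sketch below is a small sanity check one might run before training. `check_metainfo` is hypothetical. Treating 255 as an allowed value is an assumption, based on 255 being mmseg's default `ignore_index`. The `mask_values` argument could come from, for example, `np.unique` over the annotation PNGs.

```python
def check_metainfo(metainfo, mask_values, ignore_value=255):
    """Return a list of problems found in a METAINFO dict like MyCustomDataset's.

    ignore_value is assumed to be the dataset's ignore_index (255 by default in mmseg).
    """
    classes, palette = metainfo['classes'], metainfo['palette']
    problems = []
    if len(classes) != len(palette):
        problems.append(f'{len(classes)} classes but {len(palette)} palette colors')
    for color in palette:
        if len(color) != 3 or not all(0 <= c <= 255 for c in color):
            problems.append(f'invalid RGB color {color}')
    # Every label value in the masks must index a class (or be the ignore value)
    bad = sorted(v for v in set(mask_values)
                 if v != ignore_value and not 0 <= v < len(classes))
    if bad:
        problems.append(f'mask values outside 0..{len(classes) - 1}: {bad}')
    return problems
```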
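  • Register MyCustomDataset in mmseg/datasets/__init__.py. This overwrites the whole file, so the import list below has to match the installed mmsegmentation version.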
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]

"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
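Overwriting `mmseg/datasets/__init__.py` wholesale only works if the pasted file matches the installed mmsegmentation version exactly. A less fragile option, sketched below, is to append the single import line if it is missing. As I understand it, the `@DATASETS.register_module()` decorator registers the class when the module is imported. The `__all__` entry should therefore not be needed for the config lookup, but this is an assumption worth checking on your version. MMEngine's `custom_imports` config field is another route that avoids editing the package. `ensure_import_line` is a hypothetical helper.

```python
from pathlib import Path


def ensure_import_line(init_path, line='from .MyCustomDataset import MyCustomDataset'):
    """Append `line` to an existing __init__.py unless it is already present.

    Returns True if the file was modified, False if it already contained the line.
    """
    path = Path(init_path)
    text = path.read_text(encoding='utf-8')
    if line.strip() in (l.strip() for l in text.splitlines()):
        return False
    if text and not text.endswith('\n'):
        text += '\n'
    path.write_text(text + line + '\n', encoding='utf-8')
    return True
```

Usage: `ensure_import_line('mmseg/datasets/__init__.py')`. Re-running the cell is harmless because the line is only added once.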
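  • Define the dataset preprocessing pipeline
custom_pipeline = """
# Dataset
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'Watermelon87_Semantic_Seg_Mask/' # dataset path (relative to the mmsegmentation root directory)

# Crop size fed to the model: usually a multiple of 128; smaller sizes use less GPU memory
crop_size = (640, 640)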

# Training pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# Test pipeline
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# Test-time augmentation (TTA) pipeline
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# Training dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline))

# Validation dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))

# Test dataloader
test_dataloader = val_dataloader

# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# Test evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

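Modifying the config file

  • The main changes: number of classes, pretrained weight path, crop size (usually a multiple of 128), batch_size, the learning rate (scaled as base_lr_default * (your_bs / default_bs)), and the learning-rate decay schedule
  • Learning rate: change the lr of the optimizer. The runner builds the optimizer from cfg.optim_wrapper, so make sure the new value ends up in cfg.optim_wrapper.optimizer.
  • Freeze the model's backbone; for Mask2Former this speeds up training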
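Here is the linear scaling rule from the first bullet, worked out for this setup. The `8xb2` in the config name means 8 GPUs with 2 images each, a total batch of 16. This notebook trains on one GPU with a batch of 2, and that ratio is where the `/ 8` in the config code comes from. `scale_lr` is a hypothetical helper, and the base LR itself is whatever the Cityscapes config defines.

```python
def scale_lr(base_lr, default_batch_size, new_batch_size):
    """Linear scaling rule: new_lr = base_lr * new_batch_size / default_batch_size."""
    return base_lr * new_batch_size / default_batch_size


default_bs = 8 * 2   # '8xb2' in the config name: 8 GPUs x 2 images per GPU
new_bs = 1 * 2       # one Kaggle P100 with batch_size = 2
factor = scale_lr(1.0, default_bs, new_bs)
print(factor)  # 0.125, i.e. divide the default LR by 8
```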
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
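# Number of classes
NUM_CLASS = 6
# For single-GPU training, replace SyncBN with BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# Pretrained weights
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# Decode head: set the number of classes; class_weight has one extra entry (0.1) for Mask2Former's "no object" class
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
# Freeze the backbone
cfg.model.backbone.frozen_stages = 4


# Training batch size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader


# Scale the LR linearly with batch size: 2 / (8 GPUs x 2) = 1/8.
# The runner builds the optimizer from cfg.optim_wrapper, whose optimizer entry may be a
# separate copy of cfg.optimizer after loading, so write the new value there as well.
cfg.optimizer.lr = cfg.optimizer.lr / 8
cfg.optim_wrapper.optimizer.lr = cfg.optimizer.lr

# Output directory
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 4000 # total training iterations
cfg.train_cfg.val_interval = 50 # validate every N iterations
cfg.default_hooks.logger.interval = 50 # logging interval
cfg.default_hooks.checkpoint.interval = 50 # checkpoint saving interval
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # maximum number of checkpoints to keep
cfg.default_hooks.checkpoint.save_best = 'mIoU' # also keep the checkpoint with the best mIoU

cfg.param_scheduler[0].end = cfg.train_cfg.max_iters
# Random seed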
cfg['randomness'] = dict(seed=0)

cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
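Training is iteration-based (`InfiniteSampler` plus `max_iters`). Translating iterations into epochs helps judge overfitting on a small dataset. The watermelon training-set size is not given in this post, so the 70 images below are a purely hypothetical placeholder. Count the files in `img_dir/train` for the real number. Both helpers are hypothetical.

```python
def iters_to_epochs(max_iters, batch_size, num_train_images):
    """Approximate number of passes over the training set for an iteration-based schedule."""
    return max_iters * batch_size / num_train_images


def num_events(max_iters, interval):
    """How many validations (or checkpoint saves) an interval triggers over max_iters."""
    return max_iters // interval


# 70 is a HYPOTHETICAL training-set size; replace it with the real file count
print(round(iters_to_epochs(4000, 2, 70), 1))  # 114.3
print(num_events(4000, 50))                    # 80 validation runs
```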
  • Save the config file
cfg.dump('custom_mask2former.py')
  • Start training
!python tools/train.py custom_mask2former.py
  • Pick the best checkpoint and evaluate its accuracy
# Best checkpoint (selected by mIoU)
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# Evaluate accuracy
!python tools/test.py custom_mask2former.py '{best_pth}'
  • Output:
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 |  97.1 |  97.1  |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 |  90.1 |  90.1  |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+
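Two things stand out in the table: the Dice and Fscore columns are identical, and Dice tracks IoU closely. Both follow from the per-class definitions, with TP, FP and FN counted over pixels. Dice = 2TP / (2TP + FP + FN) is exactly the F1 score of precision and recall. It relates to IoU = TP / (TP + FP + FN) by Dice = 2·IoU / (1 + IoU). The code below is a quick arithmetic check against the reported numbers. The functions are illustrative, not mmseg's `IoUMetric` implementation.

```python
def dice_from_iou_pct(iou_pct):
    """Dice (%) implied by IoU (%), using Dice = 2*IoU / (1 + IoU) on fractions."""
    iou = iou_pct / 100
    return 100 * 2 * iou / (1 + iou)


def f1_pct(precision_pct, recall_pct):
    """F1 score (%): harmonic mean of precision and recall."""
    return 2 * precision_pct * recall_pct / (precision_pct + recall_pct)


# (IoU, Dice, Precision, Recall) per class, copied from the watermelon table above
rows = {
    'background': (98.55, 99.27, 99.42, 99.12),
    'red':        (96.54, 98.24, 97.65, 98.83),
    'green':      (94.37, 97.10, 98.14, 96.08),
    'white':      (85.96, 92.45, 92.24, 92.67),
    'seed-black': (81.98, 90.10, 89.34, 90.87),
    'seed-white': (65.57, 79.21, 91.24, 69.98),
}

for name, (iou, dice, p, r) in rows.items():
    print(f'{name:10s} table={dice:.2f}  from IoU={dice_from_iou_pct(iou):.2f}  '
          f'from P/R={f1_pct(p, r):.2f}')

# mIoU as an unweighted mean of the per-class IoUs
miou = sum(v[0] for v in rows.values()) / len(rows)
print(f'mIoU = {miou:.2f}')  # mIoU = 87.16
```

The seed-white row also shows recall (69.98) well below precision (91.24). In other words, the model misses more seed-white pixels than it wrongly labels as seed-white.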

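Visualizing training metrics

[Figure 3: training metric curves, watermelon dataset]

Fine-tuning on the glomeruli dataset

  • Fine-tune Mask2Former on a dataset with a single foreground class (glomeruli in histopathology slides)
  • First, clear the working directory, the data folder and the outputs folder
# Clear the working directory
!rm -r work_dirs/*
# Clear the data folder
!rm -r data/*
# Clear the outputs folder
!rm -r outputs/*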

Visual exploration of the semantic segmentation dataset

# Image and mask directories
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

mask = cv2.imread(os.path.join(PATH_MASKS, 'VUHSK_1762_29.png'))
# Check which label values occur in the mask
np.unique(mask)
  • Output:
array([0, 1], dtype=uint8)
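`np.unique` shows that the masks contain only 0 and 1, that is, background plus one foreground class. Counting how many pixels fall into each class across many masks gives a sense of the class imbalance. That may help explain why the foreground IoU later ends up well below the background IoU. `class_pixel_fractions` is a hypothetical helper.

```python
import numpy as np


def class_pixel_fractions(masks, num_classes):
    """Fraction of pixels per class id over an iterable of 2-D label masks."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for m in masks:
        m = np.asarray(m, dtype=np.int64).ravel()
        if m.size and (m.min() < 0 or m.max() >= num_classes):
            raise ValueError(f'label values outside 0..{num_classes - 1}')
        counts += np.bincount(m, minlength=num_classes)
    return counts / counts.sum()
```

Usage, reading masks as single-channel images: `class_pixel_fractions((cv2.imread(p, cv2.IMREAD_GRAYSCALE) for p in glob.glob(f'{PATH_MASKS}/*.png')), 2)`.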
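  • Visualize the images together with their semantic labels
# Show an n x n grid
n = 5

# Mask opacity: the smaller the value, the closer the overlay looks to the original image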
opacity = 0.65

fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))

for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
    
    # Load the image and its mask
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    
    # Visualize
    axes[i//n, i%n].imshow(img[:,:,::-1])
    axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)
    axes[i//n, i%n].axis('off') # hide the axes
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()

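Splitting the training and validation sets

  • Create the training and validation folders
# Image train/val folders
!mkdir -p data/images/train
!mkdir -p data/images/val

# Mask train/val folders
!mkdir -p data/masks/train
!mkdir -p data/masks/val
  • Shuffle the data randomly and split it into 90% training and 10% validation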
def copy_file(og_images, og_masks, tr_images, tr_masks, train_ratio):
    # All file names in the source image folder (each mask is assumed to share its image's file name)
    file_names = os.listdir(og_images)
    
    # Shuffle the file names randomly
    random.shuffle(file_names)
    
    # Split point between training and validation files
    split_index = int(train_ratio * len(file_names))
    
    # Copy the training files
    for file_name in file_names[:split_index]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'train', file_name)
        tr_mask = os.path.join(tr_masks, 'train', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

    # Copy the validation files
    for file_name in file_names[split_index:]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'val', file_name)
        tr_mask = os.path.join(tr_masks, 'val', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

tr_images = 'data/images'
tr_masks = 'data/masks'

copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)
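`copy_file` shuffles with the global `random` state and no seed, so each run produces a different split. It also assumes every image has a mask with the same name. Below is a sketch of a reproducible split plus a name check. `split_names` and `missing_masks` are hypothetical helpers. Sorting before shuffling makes the result independent of `os.listdir` order.

```python
import os
import random


def split_names(file_names, train_ratio=0.9, seed=0):
    """Reproducible train/val split of file names."""
    names = sorted(file_names)     # fixed starting order, independent of os.listdir
    rng = random.Random(seed)      # local RNG: leaves the global random state untouched
    rng.shuffle(names)
    split = int(train_ratio * len(names))
    return names[:split], names[split:]


def missing_masks(image_dir, mask_dir):
    """Image file names that have no mask with the same name."""
    masks = set(os.listdir(mask_dir))
    return sorted(n for n in os.listdir(image_dir) if n not in masks)
```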

Redefining the Dataset and Pipeline

  • Mainly change the class names and their RGB colors
  • as well as the data paths in the dataloaders
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
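    # Class names and their RGB colors
    METAINFO = {
        'classes':['normal','sclerotic'],
        'palette':[[127,127,127],[251,189,8]]
    }
    
    # File suffixes for images and annotations
    def __init__(self, img_suffix='.png',
                 seg_map_suffix='.png',   # file suffix of the mask images
                 reduce_zero_label=False, # if True, label 0 is ignored and the other labels shift down by 1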
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]

"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
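  • Define the data preprocessing pipeline
custom_pipeline = """
# Dataset
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'data/' # dataset path (relative to the mmsegmentation root directory)

# Crop size fed to the model: usually a multiple of 128; smaller sizes use less GPU memory
crop_size = (640, 640)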

# Training pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# Test pipeline
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# Test-time augmentation (TTA) pipeline
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# Training dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/train', seg_map_path='masks/train'),
        pipeline=train_pipeline))

# Validation dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/val', seg_map_path='masks/val'),
        pipeline=test_pipeline))

# Test dataloader
test_dataloader = val_dataloader

# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# Test evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)
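How big are the training images after `RandomResize`? Assume the glomeruli tiles are 1024×1024, as the dataset's `images_1024` folder name suggests. As I understand mmcv's `RandomResize`, the `scale` bound is multiplied by a random ratio from `ratio_range` and then applied with `keep_ratio=True`. The scale factor is min(long bound / long side, short bound / short side). Under those assumptions the arithmetic works out as below. `keep_ratio_size` is a hypothetical re-derivation, not the mmcv function.

```python
def keep_ratio_size(w, h, scale):
    """Target (w, h) when an image is resized with keep_ratio=True to fit a (long, short) bound."""
    long_edge, short_edge = max(scale), min(scale)
    factor = min(long_edge / max(w, h), short_edge / min(w, h))
    return int(w * factor + 0.5), int(h * factor + 0.5)


# RandomResize(scale=(2048, 1024), ratio_range=(0.5, 2.0)) on a 1024x1024 tile,
# shown for the two ends of the ratio range and the middle
for r in (0.5, 1.0, 2.0):
    bound = (int(2048 * r), int(1024 * r))
    print(r, keep_ratio_size(1024, 1024, bound))
# 0.5 (512, 512)
# 1.0 (1024, 1024)
# 2.0 (2048, 2048)
```

For a square 1024 tile, the resized side is 1024·r. For ratios below 0.625, the image is therefore smaller than the 640×640 crop. The test pipeline's `Resize(scale=(2048, 1024))` leaves such a tile at 1024×1024.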

Modifying the config file

cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
  • Update the config
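# Number of classes
NUM_CLASS = 2
# For single-GPU training, replace SyncBN with BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# Pretrained weights
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# Decode head: set the number of classes; class_weight has one extra entry (0.1) for Mask2Former's "no object" class
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
# Freeze the backbone
cfg.model.backbone.frozen_stages = 4


# Training batch size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader


# Scale the LR linearly with batch size: 2 / (8 GPUs x 2) = 1/8.
# The runner builds the optimizer from cfg.optim_wrapper, whose optimizer entry may be a
# separate copy of cfg.optimizer after loading, so write the new value there as well.
cfg.optimizer.lr = cfg.optimizer.lr / 8
cfg.optim_wrapper.optimizer.lr = cfg.optimizer.lr

# Output directory
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 40000 # total training iterations
cfg.train_cfg.val_interval = 500 # validate every N iterations
cfg.default_hooks.logger.interval = 50 # logging interval
cfg.default_hooks.checkpoint.interval = 2500 # checkpoint saving interval
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # maximum number of checkpoints to keep
cfg.default_hooks.checkpoint.save_best = 'mIoU' # also keep the checkpoint with the best mIoU

# Random seed
cfg['randomness'] = dict(seed=0)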

cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • Save the config file and start training
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py

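Visualizing training metrics

[Figure 4: training metric curves, glomeruli dataset]

Evaluating the model and measuring inference speed

  • Evaluate accuracy
# Best checkpoint (selected by mIoU)
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# Evaluate accuracy
!python tools/test.py custom_mask2former.py '{best_pth}'
  • Output: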
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
  • Measure inference speed
# Measure FPS
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
  • Output:
Done image [50 / 200], fps: 2.24 img / s
Done image [100/ 200], fps: 2.24 img / s
Done image [150/ 200], fps: 2.24 img / s
Done image [200/ 200], fps: 2.24 img / s
Overall fps: 2.24 img / s

Average fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0
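2.24 img/s corresponds to roughly 1000 / 2.24 ≈ 446 ms per validation image on the P100. To time your own inference loop instead of using the benchmark script, a generic approach is to warm up first and then time with `time.perf_counter`. `measure_fps` is hypothetical. For a GPU model, `fn` should call `torch.cuda.synchronize()` before returning. CUDA kernels run asynchronously, so otherwise the timer would stop before the work finishes.

```python
import time


def measure_fps(fn, inputs, warmup=5):
    """Items per second for fn over inputs, excluding the first `warmup` calls."""
    inputs = list(inputs)
    if len(inputs) <= warmup:
        raise ValueError('need more inputs than warm-up iterations')
    for x in inputs[:warmup]:   # warm-up: lazy initialization, caches, memory allocation
        fn(x)
    timed = inputs[warmup:]
    start = time.perf_counter()
    for x in timed:
        fn(x)
    return len(timed) / (time.perf_counter() - start)


print(f'{1000 / 2.24:.0f} ms per image')  # 446 ms per image
```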
