MMSegmentation使用心得(一)

目录

  • 1 MMSegmentation简介
  • 2 MMSegmentation示例代码讲解
    • 2.1 MMSegmentation下载
    • 2.2 导入MMSegmentation包
    • 2.3 数据集导入
    • 2.4 配置文件(Config)
    • 2.5 模型训练
    • 2.6 预测

大家好,这是本人的第一篇文章,打算做一些针对于语义分割框架MMSegmentation的使用心得,避免大家踩坑,欢迎交流。

1 MMSegmentation简介

MMSegmentation是一个基于PyTorch的开源语义分割工具箱,是OpenMMLab项目的一部分。将语义分割框架分解为不同的组件,通过组合不同的模块,构建定制的语义分割框架。同时支持大部分流行语义分割模型,训练速度和精度上表现也很好。
由于之前保研的时候事情比较多,很久没有做相关工作,下面我们就先对官方公布的示例代码进行逐段分析,并讲解一些注意事项。

2 MMSegmentation示例代码讲解

2.1 MMSegmentation下载

首先我们要查看我们nvcc和gcc的一个版本,确定能否下载相关的包。

# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version

确定版本之后,我们需要先安装对应的mmcv的包,当然在这之前要先根据你的电脑NVIDIA版本,安装合适的pytorch。

# Install PyTorch
!pip install torch==1.12.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu113

安装pytorch可以直接上官网找到合适的版本,如果已经安装了pytorch,那么我们可以安装mmcv啦,记得去查看和pytorch以及CUDA版本对应的mmcv版本,如果版本不匹配会无法运行哦。具体匹配信息见链接。

# Install MMCV
!pip install openmim
!mim install mmcv-full==1.6.0

安装好这一切之后,我们就可以直接安装MMSegmentation啦!

!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git 
%cd mmsegmentation
!pip install -e .

2.2 导入MMSegmentation包

MMSegmentation的包在系统中的存储名字为mmseg。

import mmseg

2.3 数据集导入

示例代码中给出的是一些官方的数据集,当我们自己使用的时候其实改成自己的绝对路径就可以,data_root我个人一般直接忽视,只用image和lable数据集。对于classes,你有几类定义几类即可,palette可随意,影响不大。

import os.path as osp
import numpy as np
from PIL import Image
# convert dataset annotation to semantic segmentation map
data_root = 'iccv09Data'
img_dir = 'images'
ann_dir = 'labels'
# define class and plaette for better visualization
classes = ('sky', 'tree', 'road', 'grass', 'water', 'bldg', 'mntn', 'fg obj')
palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34], 
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]
for file in mmcv.scandir(osp.join(data_root, ann_dir), suffix='.regions.txt'):
  seg_map = np.loadtxt(osp.join(data_root, ann_dir, file)).astype(np.uint8)
  seg_img = Image.fromarray(seg_map).convert('P')
  seg_img.putpalette(np.array(palette, dtype=np.uint8))
  seg_img.save(osp.join(data_root, ann_dir, file.replace('.regions.txt', 
                                                         '.png')))                                                     

下面还要数据集划分的步骤,你可以自动划分,也可以提前划分好,要注意txt文件中存储的文件名是否带有后缀,以及记得调整代码中image和label的文件格式。

 split train/val set randomly
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
  # select first 4/5 as train set
  train_length = int(len(filename_list)*4/5)
  f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
  # select last 1/5 as train set
  f.writelines(line + '\n' for line in filename_list[train_length:])

**这里要讲如何定义自己的数据集,这是很重要的,因为我们在日常使用的时候大都使用自己的数据集,如果想使用常用的数据集如COCO,CITYSPACE等,直接调用即可。

rom mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

@DATASETS.register_module()
class StanfordBackgroundDataset(CustomDataset):
  CLASSES = classes
  PALETTE = palette
  def __init__(self, split, **kwargs):
    super().__init__(img_suffix='.jpg', seg_map_suffix='.png', 
                     split=split, **kwargs)
    assert osp.exists(self.img_dir) and self.split is not None

    

我们根据数据的实际情况对期进行设置即可,主要关注数据类别和格式。

2.4 配置文件(Config)

配置文件即我们如何组装我们的模型,组合不同的网络架构及backbone,生成属于我们自己的模型。

这是调用已有config的一种方法

from mmcv import Config
cfg = Config.fromfile('configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')

在此基础上,原文对config又进行调整


from mmseg.apis import set_random_seed
from mmseg.utils import get_device

# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8

# Modify dataset type and path
cfg.dataset_type = 'StanfordBackgroundDataset'
cfg.data_root = data_root

cfg.data.samples_per_gpu = 8
cfg.data.workers_per_gpu=8

cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/tutorial'

cfg.runner.max_iters = 200
cfg.log_config.interval = 10
cfg.evaluation.interval = 200
cfg.checkpoint_config.interval = 200

# Set seed to facitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
cfg.device = get_device()

# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')

我们需要着重观察这几行
MMSegmentation使用心得(一)_第1张图片norm_cfg是很重要的一项,对于transformer,我们要将他改为SyncBN。同时要注意decode和辅助解码头的norm_cfg,有时候他们于backbone的是不一样的,即只用将backbone的norm_cfg改为SyncBN。后面就是我们最后要输出的类别,改成实际的即可。

在这之后就是数据增强的操作,我们可以对训练数据和测试数据进行不同的数据增强,以便模型又更好的表现。

cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

2.5 模型训练

全部设定好后我们可以根据配置文件,生成模型,并进行训练啦

from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_segmentor(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

这里有时候会遇到不能运行的问题,我们可以对最后一行代码进行调整,在meta字典中加入classes和palette信息。

train_segmentor(model, datasets, cfg, distributed=False, validate=True,
                meta=dict(CLASSES=model.CLASSES,PALETTE=[[0, 0, 0], [255, 255, 255]]))

2.6 预测

这里是官方给出的预测代码,也可以用官方的函数进行结果展示

img = mmcv.imread('iccv09Data/images/6000124.jpg')

model.cfg = cfg
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, palette)

训练结束后,我们导入训练好的参数就可以训练啦。这里输出的result其实就以及是类似图片的了,我们可以将它转化一些用PIL直接输出

result = inference_segmentor(model,img)
result = torch.tensor(np.array(result))
result.squeeze_()
result = result.numpy()

可能有同学会问训练参数是从哪里导入的呢,其实就是前面的config那里,刚开始训练其实可以把这里直接注释掉,训练过程中也可以继续训练,将这一行改成resume_from就可以接着训练啦!

cfg.load_from = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

示例代码的讲解到这里就结束了,大家赶快去试试吧!如果有什么问题欢迎评论区留言交流,我会尽快解答。

你可能感兴趣的:(python,深度学习,pytorch,图像处理,transformer)