【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)

文章目录

  • 0.mmSegmentation介绍
  • 1.mmSegmentation基本框架
    • 1.1.mmSegmentation的model设置
    • 1.2.mmSegmentation的dataset设置
      • 1.2.1.Dataset Class文件配置
      • 1.2.2.Dataset Config文件配置
      • 1.2.3.Total Config文件配置
  • 2.运行代码
  • 3.展示效果图和预测
  • X.附录
    • X.1.mmSegmentation框架解释
    • X.2.mmsegmentation使用的预训练backbone
    • X.2.mmsegmentation官方帮助文档

0.mmSegmentation介绍

\qquad mmSegmentation是openmmlab项目下开源的图像语义分割框架,目前支持pytorch,由于其拥有pipeline加速,完善的数据增强体系,完善的模型库,作为大数据语义分割训练及测试的代码框架是再好不过了。
\qquad 在开始本教程之前,你需要解决openmmlab的环境配置问题,好在这个repo上已经有很人性化的步骤讲解了,在此附上链接,就不赘述了:

  • Github链接:安装openmmlab环境

使用教程的相关链接如下(github的项目还自带了中文版):

  • Github链接:openmmlab/mmSegmentation
  • Gitio教程:openmmlab/mmSegmenatation

\qquad 对着mmSegmentation官方教程一步步做固然是能做出来,但是由于其框架结构过于复杂,加之官方教程对如何规范自定义数据集缺乏一些tips,因而本文提供了一个相对简单的教程供大家参考。本文所有讲解目录均为mmSegmentation的项目目录。
MMSegmentation

1.mmSegmentation基本框架

\qquad 要说mmSegmentation(以下简称mmSeg)当中最重要的东西,固然是Config文件了,Config文件可以分为4大类:

  1. model config
  2. dataset config
  3. runtime config
  4. schedule config

\qquad 如果你想知道为什么分成这四大类,请参考本文X.1.节,对这个不感兴趣就继续往下看。其实3和4大多数人都用不到的,重点还是在1和2,下面就从这两个角度给大家来一个不算精细的讲解。

1.1.mmSegmentation的model设置

\qquad 如果采用的是mmSegmentation里面支持的模型,那么固然是不需要自己写class了,自己挑一个模型就可以了。这些model的目录保存在了configs/models里面了。
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第1张图片
第一个下划线前面的都好理解,就是模型的名字呗,那r50-d8可能就是resnet的类型了,有人会问,那resnet101和resnet152哪去了,别急,其实这些只是baseline,它的backbone是可以改的,比如说我们要使用的是danet_r50-d8.py,我们先打开它(这里我已经将SyncBN改成了BN,因为需要单GPU训练):
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第2张图片
\qquad 只需要把model.backbone.depth设为101或者152就可以使用resnet101或者resnet152啦,如果你的本地没有模型,mmSeg就会从model_zoo里面下载一个,如果本地有(应该是保存在了checkpoint里面),则自动加载本地的,不会重复下载。其他的操作后面会讲,另外如果你是多GPU操作就选择使用SyncBN,否则就使用BN就可以了。如果使用了SyncBN却只有一块可用的GPU,那可能会报类似AssertionError:Default process group is not initialized的错误。有人可能问那我直接改了这个文件不就吧原来的默认参数给覆盖了嘛,不要紧,看到后面大家就会明白这个问题很容易解决,这里只是给大家做一个demo。

1.2.mmSegmentation的dataset设置

\qquad 数据集设置比model的稍微复杂一点,这里会直接定义一个自己的数据集(Custom Dataset)来说明其原理。数据集需要准备的文件有三个

  1. Dataset Class文件
  2. Dataset Config文件
  3. Total Config文件

\qquad 在X.1.节提到的config文件就是Total config(顶层设置文件),也是train.py文件直接调用的config文件,而Dataset Class文件是用来定义数据集的类别数和标签名称的,Dataset Config文件则是用来定义数据集目录、数据集信息(例如图片大小)、数据增强操作以及pipeline的。

1.2.1.Dataset Class文件配置

\qquad 首先来说Dataset Class文件,这个文件存放在 mmseg/datasets/ 目录下,
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第3张图片
\qquad 在这个目录下自己建一个数据集文件,并命个名。配置文件实际上是继承该目录下custom.py当中的CustomDataset父类的,这样写起了就简单多了,大多数情况下(当你的数据集是以一张张图片出现并且可用PIL模块读入时),你只需要设置两个参数即可——类别标签名称(CLASSES)和类别标签上色的RGB颜色(PALETTE)。以我的配置文件为例,代码如下:

from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
import os.path as osp

@DATASETS.register_module()
class MRDDataset(CustomDataset):
  CLASSES = ("background","road")
  PALETTE = [[0,0,0],[255,255,255]]
  def __init__(self, split, **kwargs):
    super().__init__(img_suffix='.png', seg_map_suffix='.png', 
                     split=split, **kwargs)
                     
    assert osp.exists(self.img_dir) and self.split is not None

\qquad img_suffixseg_map_suffix分别是你的数据集图片的后缀和标签图片的后缀,因个人差异而定,tif格式的图片我还没有试过,但是jpg和png的肯定是可以的。
\qquad 设置好之后记得保存在mmseg/datasets/目录下(我的文件名叫my_road_detect.py)。另外还需要设置一下该目录下的__init__文件:

from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .custom import CustomDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .stare import STAREDataset
from .voc import PascalVOCDataset
from .my_road_detect import MRDDataset
__all__ = [
    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
    'STAREDataset',"MRDDataset"
]

\qquad 需要改两个地方,①import的时候要把自己的Dataset加载进来,②__all__数组里面需要加入自己的Dataset类名称,修改完成之后保存。这两部操作完成之后还不行,由于训练的时候需要txt文件指示训练集、验证集和测试集的txt文件,一开始我以为这只是一个optional option,但无奈Custom Dataset的__init___下面给我来了一句assert osp.exists(self.img_dir) and self.split is not None,那好吧,不知道删了and后面的条件会有什么后果,还是自己创一个吧,写来一个简单的划分数据集并保存到txt的demo,大家可以把这个py文件放到你的数据集上一级目录上并对着稍微改改:

import mmcv
import os.path as osp
data_root = "/data3/datasets/Custom/Lab/Segmentation/"
ann_dir = "ann_png1"
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
  # select first 4/5 as train set
  train_length = int(len(filename_list)*4/5)
  f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
  # select last 1/5 as train set
  f.writelines(line + '\n' for line in filename_list[train_length:])

data_root写自己的工作目录名称,ann_dir写标签图片所在的目录,split_dir则是在data_root下生成split txt文件保存的文件夹目录,其他的就不需要怎么改了。如果你在data_root/split_dir/下成功找到了train.txt和val.txt文件,就没有问题了。

1.2.2.Dataset Config文件配置

\qquad Dataset Config文件在 configs/__base__/ 目录下,需要自己新建一个xxx.py文件。
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第4张图片
还是以我自己的Custom Dataset为例,它的书写格式如下:

# dataset settings
dataset_type = 'MRDDataset'
data_root = '/data3/datasets/Custom/Lab/Segmentation/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 480)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(640, 480), keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(640, 480),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='data1_for_ann',
        ann_dir='ann_png1/',
        pipeline=train_pipeline,
        split="splits/train.txt"),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='data1_for_ann',
        ann_dir='ann_png1',
        split="splits/val.txt",
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='data1_for_ann',
        ann_dir='ann_png1',
        split="splits/val.txt",
        pipeline=test_pipeline))

需要改的地方有以下几个:

  1. img_norm_cfg:数据集的方差和均值
  2. crop_size:数据增强时裁剪的大小. img_dir:
  3. img_scale:原图像尺寸
  4. data_root:工作目录
  5. img_dir:工作目录下存图片的目录
  6. ann_dir:工作目录下存标签的目录
  7. split:之前操作做txt文件的目录
  8. sample_per_gpu:batch size
  9. workers_per_gpu:dataloader的线程数目,一般设2,4,8,根据CPU核数确定,或使用os.cpu_count()函数代替
  10. PhotoMetricDistortion是数据增强操作,有四个参数(参考博客)分别是亮度、对比度、饱和度和色调,它们的默认设定如下:
brightness_delta=32; # 32 
contrast_range=(0.5, 1.5); # (0.5, 1.5),下限-上限
saturation_range=(0.5, 1.5); # (0.5, 1.5),下限-上限
hue_delta=18; # 18

如果不想使用默认设定,仿照其他选项将自定义参数写在后面即可,例如

dict(type='PhotoMetricDistortion',contrast_range=(0.5, 1.0))

改好之后保存 configs/__base__/ 目录下。
\qquad 这里也给大家提供了计算数据集方差和均值的一个样例程序(多数据集计算整体均值和标准差):

# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 10:38:17 2021

@author: 17478
"""
import os
import cv2
import numpy as np
from tqdm import tqdm  # pip install tqdm
import argparse

def input_args():
    parser = argparse.ArgumentParser(description="calculating mean and std")
    parser.add_argument("--data_fmt",type=str,default='samples_{name}')
    parser.add_argument("--data-name",type=str,nargs="+",default=['morning','noon','afternoon','dusk','snowy'])
    return parser.parse_args()


if __name__ == "__main__":
    opt = input_args()
    img_files =[]
    for name in opt.data_name:
        img_dir = opt.data_fmt.format(name=name)
        files = os.listdir(img_dir)
        img_files.extend([os.path.join(img_dir,file) for file in files])

    meanRGB = np.asarray([0,0,0],dtype=np.float64)
    varRGB = np.asarray([0,0,0],dtype=np.float64)
    for img_file in tqdm(img_files,desc="calculating mean",mininterval=0.1):
        img = cv2.imread(img_file,-1)
        meanRGB[0] += np.mean(img[:,:,0])/255.0
        meanRGB[1] += np.mean(img[:,:,1])/255.0
        meanRGB[2] += np.mean(img[:,:,2])/255.0
    meanRGB = meanRGB/len(img_files)
    for img_file in tqdm(img_files,desc="calculating var",mininterval=0.1):
        img = cv2.imread(img_file,-1)
        varRGB[0] += np.sqrt(np.mean((img[:,:,0]/255.0-meanRGB[0])**2))
        varRGB[1] += np.sqrt(np.mean((img[:,:,1]/255.0-meanRGB[1])**2))
        varRGB[2] += np.sqrt(np.mean((img[:,:,2]/255.0-meanRGB[2])**2))
    varRGB = varRGB/len(img_files)
    print("meanRGB:{}".format(meanRGB))
    print("stdRGB:{}".format(varRGB))

1.2.3.Total Config文件配置

\qquad Total Config文件是train.py直接调用的config文件,在X.1.节也有介绍,在此只说明如何即可。该文件在 config/xxxmodel/ 的目录下,你选用的是哪一个model,就选择哪一个目录。
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第5张图片
以DANet为例,我们书写一个total config文件,并保存在configs/danet的文件夹下:

_base_ = [
    '../_base_/models/danet_r50-d8.py', '../_base_/datasets/my_road_detect.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
]
model = dict(
	decode_head=dict(num_classes=2),auxiliary_head=dict(num_classes=2))

\qquad 这个代码就一个__base__的数组,第一个元素代表模型路径,也就是在1.1.节介绍的模型文件(在这个教程里就不带着大家重写模型了);第二个元素代表数据集的Dataset config文件(详见1.2.2节);第三个元素和第四个元素本教程未涉及到,按照默认参数写也没有太大问题,如果想修改训练的代数以及log和save的频率修改第4元素及响应文件,在此就不再赘述了。另外如果你的模型不是19类的(因为是原模型是根据cityscapes写的,输出通道为19),需按照上面修改一下。
\qquad 到此为止要恭喜大家,代码终于可以试跑了,如果你的代码出现Error或者Exception也不要慌,从环境配置到流程一一对照一遍,调试大项目要有耐心,也欢迎大家评论区留言。

2.运行代码

\qquad 在项目目录下,输入python tools/train.py xxxconfig.py --work-dir=xxx即可运行,其中xxxconfig.py就是我们刚刚保存的Total config文件(记得要把完整路径也加上),work-dir其实就是保存log和model的目录(如果没有会自己创建)。如果发现import mmseg找不到这个包,那八成是调试器运行目录不在根目录下造成的,要不就配置run的目录,要不就直接吧tools/train.py复制到根目录下运行。运行结果差不多是这样:
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第6张图片
使用gpustat的包查看gpu状态
gpu
\qquad 虽然我的数据集很小(做测试的,就50张图片),但是gpu利用率仍然接近100%,可见其代码优化做的已经相当理想了。(我开了NVIDIA的图形加速,所以出现了很多其他的利用进程)。
\qquad 这里有读者会疑问为什么上面不显示epoch,因为mmseg默认是iteration-based的,所谓iteration即batch的个数,若要改成epoch,则需要参考docs/config.md进行修改:

runner = dict(type='EpochBasedRunner',
							max_epoch='200')
checkpoint_config = dict(by_epoch=True,
													interval=20)  # save checkpoint per 20 epochs

以上代码可放在Total config文件中。

3.展示效果图和预测

\qquad 最后写了展示预测效果的代码,把config_file和checkpoint_file替换成你自己的config文件和pth文件(保存模型的)即可:

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette
config_file = "configs/danet/danet_r50-d8_360x480_20k_mrd.py"
checkpoint_file = 'work_dirs/danet_r50-d8_375x1242_20k_mrd/latest.pth'
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
img = '/data3/datasets/Custom/Lab/Segmentation/data1_for_ann/000000.png'
result = inference_segmentor(model, img)
show_result_pyplot(model, img, result, [[0,255,0],[255,255,255]])

【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第7张图片
\qquad 我上的是白色(道路)和绿色(非道路),不是特别好看,哈哈,但是mask和img的相对位置很容易看出来,这个配颜色的话,大家还是自己定吧。我这个数据集太少,只是给大家做个演示,结果肯定是过拟合的。

X.附录

X.1.mmSegmentation框架解释

在mmSegmentation的项目目录下,打开Configs/下面的目录
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第8张图片
随便打开一个文件就知道了
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第9张图片
从文件的名字也可以看出,它是模型(baseline+backbone、数据集、schedule的组合(runtime是default设置,就没包含在名称内)。

X.2.mmsegmentation使用的预训练backbone

预训练backbone下载链接为:
mmcv预训练模型下载地址(.json文件,复制对应模型的链接即可下载)
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第10张图片

X.2.mmsegmentation官方帮助文档

可在docs/tutorials中查看
【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)_第11张图片

希望本文对您有帮助,谢谢阅读!

你可能感兴趣的:(Python,深度学习,pytorch,mmsegmentation,openmmlab,语义分割)