MMSegmentation: Training on Your Own Dataset

Create a new file mydataset.py under /mmsegmentation-master/mmseg/datasets:

# Copyright (c) OpenMMLab. All rights reserved.

from .builder import DATASETS
from .custom import CustomDataset

@DATASETS.register_module()
class MyDataset(CustomDataset):
    """My Dataset.

    """
    CLASSES = ('background', 'alpha')  # modify: your class names
    PALETTE = [[0, 0, 0], [255, 255, 255]]  # modify: one display color per class

    def __init__(self, **kwargs):
        super(MyDataset, self).__init__(
            img_suffix='.png',  # modify: suffix of your image files
            seg_map_suffix='.png',  # modify: suffix of your annotation files
            reduce_zero_label=False,
            **kwargs)
        assert self.file_client.exists(self.img_dir)
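
Note that MMSegmentation expects the annotation PNGs to store class indices (here 0 for background, 1 for alpha), not display colors; PALETTE is only used for visualization. If your binary masks use 0/255, a minimal one-off conversion sketch (the file path is just a placeholder):

# Hypothetical conversion: map a 0/255 binary mask to 0/1 class indices.
import mmcv
import numpy as np

mask = mmcv.imread('annotations/training/0001.png', flag='unchanged')
mask = (mask > 127).astype(np.uint8)  # background -> 0, alpha -> 1
mmcv.imwrite(mask, 'annotations/training/0001.png')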

Next, register the new dataset in __init__.py under /mmsegmentation-master/mmseg/datasets:

# Copyright (c) OpenMMLab. All rights reserved.
from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .custom import CustomDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
                               RepeatDataset)
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .voc import PascalVOCDataset
from .mydataset import MyDataset  # modify: import the new dataset

__all__ = [
    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
    'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
    'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'MyDataset',  # modify: add the new dataset
]
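
With the import in place, you can quickly verify that the dataset is registered (a minimal sanity check, assuming mmseg is installed from this source tree):

from mmseg.datasets import DATASETS

# Prints the MyDataset class; None means the registration did not take effect.
print(DATASETS.get('MyDataset'))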

The class_names.py file under /mmsegmentation-master/mmseg/core/evaluation also needs to be updated:

def mydataset_classes():
    return ['background', 'alpha']  # add: must match CLASSES above

def mydataset_palette():
    return [[0, 0, 0], [255, 255, 255]]  # add: must match PALETTE above
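
In this version of MMSegmentation, get_classes() and get_palette() resolve dataset names through the dataset_aliases dict at the bottom of class_names.py, so an alias entry should be added there as well (a sketch; keep the existing entries as they are):

dataset_aliases = {
    # ... existing entries ...
    'mydataset': ['mydataset'],  # add: lets get_classes('mydataset') find the functions above
}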

With the dataset code in place, the model config also needs updating. Modify the corresponding model file under configs/_base_/models, taking deeplabv3_unet_s5-d16.py as an example:

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='UNet',
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        with_cp=False,
        conv_cfg=None,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        upsample_cfg=dict(type='InterpConv'),
        norm_eval=False),
    decode_head=dict(
        type='ASPPHead',
        in_channels=64,
        in_index=4,
        channels=16,
        dilations=(1, 12, 24, 36),
        dropout_ratio=0.1,
        num_classes=2,  # may need to modify: number of classes, including background
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),  # may need to modify
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=128,
        in_index=3,
        channels=64,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2,  # may need to modify: number of classes, including background
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),  # may need to modify
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='slide', crop_size=256, stride=170))
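
Instead of editing the shared _base_ model file, these fields can also be overridden from the final config, since _base_ values are merged by key (the usual MMSegmentation override pattern):

# In the final config file, only the fields that differ need to be listed:
model = dict(
    decode_head=dict(num_classes=2),
    auxiliary_head=dict(num_classes=2))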

Next, set the batch size and the train/val/test paths. Create a new mydataset.py under configs/_base_/datasets:

# dataset settings
dataset_type = 'MyDataset'  # modify: the class name registered in mmseg/datasets/__init__.py
data_root = 'data/mydataset'  # modify: path where the dataset is stored
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (2336, 3504)  # modify: size of your input images
crop_size = (256, 256)  # modify: crop size for training
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']) 
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
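
With the img_dir/ann_dir settings above, data_root is expected to contain images/training, annotations/training, images/validation, and annotations/validation. Before training, the config and data layout can be checked quickly (a minimal sketch):

from mmcv import Config
from mmseg.datasets import build_dataset

cfg = Config.fromfile('configs/_base_/datasets/mydataset.py')
train_set = build_dataset(cfg.data.train)  # wraps MyDataset in RepeatDataset
print(len(train_set))  # fails earlier if paths or suffixes are wrong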

Finally, create a new config file deeplabv3_unet_s5_d16_256x256_40k_mydataset.py under configs/unet:

_base_ = [
    '../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/mydataset.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
model = dict(test_cfg=dict(crop_size=(256, 256), stride=(170, 170)))
evaluation = dict(metric='mDice')
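
Any other _base_ value can be overridden from this file in the same way; for example, if GPU memory is tight, the batch size from the dataset config can be reduced here (a sketch):

data = dict(samples_per_gpu=2)  # overrides samples_per_gpu=4 from the _base_ dataset config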

Then start training:

python3 ./tools/train.py ./configs/unet/deeplabv3_unet_s5_d16_256x256_40k_mydataset.py
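
tools/train.py also accepts a few useful options in this version, such as --work-dir to change the output directory and --resume-from to continue from a saved checkpoint, e.g.:

python3 ./tools/train.py ./configs/unet/deeplabv3_unet_s5_d16_256x256_40k_mydataset.py --work-dir ./work_dirs/my_run --resume-from ./work_dirs/my_run/latest.pth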


Run testing (with the 40k iteration schedule, checkpoints are saved as iter_*.pth, with a latest.pth link, under work_dirs):

python3 ./tools/test.py ./configs/unet/deeplabv3_unet_s5_d16_256x256_40k_mydataset.py ./work_dirs/deeplabv3_unet_s5_d16_256x256_40k_mydataset/latest.pth --show-dir=./work_dirs/gs_results

The predicted segmentation results for the test images are written to the --show-dir directory.
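
To compute metrics on the validation split instead of only saving visualizations, pass --eval with the desired metric:

python3 ./tools/test.py ./configs/unet/deeplabv3_unet_s5_d16_256x256_40k_mydataset.py ./work_dirs/deeplabv3_unet_s5_d16_256x256_40k_mydataset/latest.pth --eval mDice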
