PaddleSeg分割框架解读[01] 核心设计解析

文章目录

  • PaddleSeg分割框架解读[01] 核心设计解析
    • tools/train.py
    • paddleseg/cvlibs/config.py
    • paddleseg/cvlibs/builder.py
    • paddleseg/cvlibs/manager.py

PaddleSeg分割框架解读[01] 核心设计解析

tools/train.py

import argparse
import random
import numpy as np
import cv2

import paddle
from paddleseg.cvlibs import Config, SegBuilder
from paddleseg.utils import get_sys_env, logger, utils
from paddleseg.core import train


def parse_args():
    """
    Define and parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Model training')
    # Config file describing the model, datasets, losses and optimizer.
    parser.add_argument(
        "--config", 
        help="The config file.",
        type=str,
        default=None)
    # Device used for training.
    parser.add_argument(
        '--device',
        help='Set the device place for training model.',
        type=str,
        default='gpu',
        choices=['cpu', 'gpu', 'xpu', 'npu', 'mlu'])
    # Directory where model snapshots are saved.
    parser.add_argument(
        '--save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    # Number of worker processes of the data loader.
    parser.add_argument(
        '--num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)
    # Whether to evaluate on the validation set while training.
    parser.add_argument(
        '--do_eval',
        help='Eval while training',
        action='store_true')
    # Whether to record training data to VisualDL.
    parser.add_argument(
        '--use_vdl',
        help='Whether to record the data to VisualDL during training',
        action='store_true')
    # Whether to keep an EMA (exponential moving average) of the model during
    # training; the flag is forwarded to `train(use_ema=...)`.
    parser.add_argument(
        '--use_ema',
        help='Whether to ema the model in training.',
        action='store_true')
    # Path of a saved model to resume training from.
    parser.add_argument(
        '--resume_model',
        help='The path of resume model',
        type=str,
        default=None)
    # Number of training iterations (overrides `iters` in the config when given).
    parser.add_argument(
        '--iters',
        help='iters for training',
        type=int,
        default=None)
    # Batch size per device (overrides `batch_size` in the config when given).
    parser.add_argument(
        '--batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=None)
    # Initial learning rate (overrides `lr_scheduler.learning_rate` when given).
    parser.add_argument(
        '--learning_rate',
        help='Learning rate',
        type=float,
        default=None)
    # Interval (in iterations) between saved snapshots.
    parser.add_argument(
        '--save_interval',
        help='How many iters to save a model snapshot once during training.',
        type=int,
        default=1000)
    # Interval (in iterations) between log messages.
    parser.add_argument(
        '--log_iters',
        help='Display logging information at every log_iters',
        type=int,
        default=10)
    # Maximum number of checkpoints kept on disk.
    parser.add_argument(
        '--keep_checkpoint_max',
        help='Maximum number of checkpoints to save',
        type=int,
        default=20)
    # Random seed.
    parser.add_argument(
        '--seed',
        help='Set the random seed during training.',
        type=int,
        default=None)
    # Automatic mixed precision ('fp16') or normal ('fp32') training.
    parser.add_argument(
        "--precision",
        type=str,
        default="fp32",
        choices=["fp32", "fp16"],
        help="Use AMP (Auto mixed precision) if precision='fp16'. If precision='fp32', the training is normal."
    )
    # Automatic mixed precision level.
    parser.add_argument(
        "--amp_level",
        default="O1",
        type=str,
        choices=["O1", "O2"],
        help="Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, "
             "the input data type of each operator will be casted by white_list and black_list; "
             "O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, "
             "except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp).")
    # Options of the training profiler; the profiler is enabled when this is set.
    parser.add_argument(
        '--profiler_options',
        type=str,
        default=None,
        # The two sentences are joined implicitly, so the first must end with
        # ". " or the help would read "enabledRefer".
        help='The option of train profiler. If profiler_options is not None, the train profiler is enabled. '
             'Refer to the paddleseg/utils/train_profiler.py for details.'
    )
    # Layout of the input data. Only these two layouts exist; restricting the
    # choices turns a typo (e.g. "nhwc") into an error instead of a silent NCHW run.
    parser.add_argument(
        '--data_format',
        help='Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW".',
        type=str,
        default='NCHW',
        choices=['NCHW', 'NHWC'])
    # How many times the training samples are repeated within one epoch.
    parser.add_argument(
        '--repeats',
        type=int,
        default=1,
        help="Repeat the samples in the dataset for `repeats` times in each epoch."
    )
    # Override arbitrary config entries with key=value pairs, e.g.
    # `--opts batch_size=1 test_config.scales=0.75,1.0`.
    # nargs='+' means at least one value must follow the flag.
    parser.add_argument(
        '--opts', 
        help='Update the key-value pairs of all options.', 
        nargs='+')

    return parser.parse_args()


def main(args):
    """
    Build all training components from the config file and start training.

    Args:
        args (argparse.Namespace): Command-line arguments returned by `parse_args()`.

    Raises:
        ValueError: If `--config` is missing, or if the "NHWC" data format is
            requested for a model other than DeepLabV3P.
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if args.config is None:
        raise ValueError('No configuration file specified, please set --config')

    # Load the yaml config; command-line values, when given, override it.
    cfg = Config(
        args.config,
        learning_rate=args.learning_rate,
        iters=args.iters,
        batch_size=args.batch_size,
        opts=args.opts)

    # The NHWC layout is only supported by DeepLabV3P. Propagate it into every
    # config section that takes a `data_format` before the config is printed
    # and before any component is built, so the logged config matches what is
    # actually created.
    if args.data_format == 'NHWC':
        if cfg.dic['model']['type'] != 'DeepLabV3P':
            raise ValueError('The "NHWC" data format only support the DeepLabV3P model!')
        cfg.dic['model']['data_format'] = args.data_format
        cfg.dic['model']['backbone']['data_format'] = args.data_format
        for loss_type_cfg in cfg.dic['loss']['types']:
            loss_type_cfg['data_format'] = args.data_format

    # Components are built lazily, when the corresponding builder property is
    # first accessed below.
    builder = SegBuilder(cfg)

    utils.show_env_info()
    utils.show_cfg_info(cfg)
    utils.set_seed(args.seed)
    utils.set_device(args.device)
    utils.set_cv2_num_threads(args.num_workers)

    model = utils.convert_sync_batchnorm(builder.model, args.device)

    train_dataset = builder.train_dataset
    # Repeat the sample list so that one epoch visits every sample `repeats` times.
    if args.repeats > 1:
        train_dataset.file_list *= args.repeats
    # The validation set is only needed when evaluating during training.
    val_dataset = builder.val_dataset if args.do_eval else None
    optimizer = builder.optimizer
    loss = builder.loss

    train(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        optimizer=optimizer,
        save_dir=args.save_dir,
        iters=cfg.iters,
        batch_size=cfg.batch_size,
        resume_model=args.resume_model,
        save_interval=args.save_interval,
        log_iters=args.log_iters,
        num_workers=args.num_workers,
        use_vdl=args.use_vdl,
        use_ema=args.use_ema,
        losses=loss,
        keep_checkpoint_max=args.keep_checkpoint_max,
        test_config=cfg.test_config,
        precision=args.precision,
        amp_level=args.amp_level,
        profiler_options=args.profiler_options,
        to_static_training=cfg.to_static_training)


if __name__ == '__main__':
    # Script entry point: parse the command line, then train with it.
    main(parse_args())

paddleseg/cvlibs/config.py

import six
import codecs
import os
from ast import literal_eval
from typing import Any, Dict, Optional
import yaml

import paddle
from paddleseg.cvlibs import config_checker as checker
from paddleseg.cvlibs import manager
from paddleseg.utils import logger, utils

_INHERIT_KEY = '_inherited_'
_BASE_KEY = '_base_'


class Config(object):
    """
    Training configuration parsed from a yaml/yml file.

    The file is loaded into the dict ``self.dic`` (resolving ``_base_``
    inheritance), the command-line overrides are applied, and a ConfigChecker
    then validates the result. Sections of the config are exposed as
    read-only properties.

    Hyper-parameters in the config file:
        batch_size: Number of samples per GPU.
        iters: Total number of training iterations.
        train_dataset: Training dataset config, including type/transforms/mode
            and dataset-specific arguments.
            For dataset types, please refer to paddleseg.datasets.
            For specific transforms, please refer to paddleseg.transforms.transforms.
        val_dataset: Validation dataset config, with the same layout as train_dataset.
        optimizer: Optimizer config, please refer to paddleseg.optimizers.
        lr_scheduler: Learning-rate scheduler config. ``type`` names the
            scheduler (e.g. PolynomialDecay) and ``learning_rate`` is the initial
            learning rate. Apart from the optional ``warmup_iters`` /
            ``warmup_start_lr`` pair, the remaining keys are passed to the
            scheduler and should be tuned experimentally.
        loss: Loss config, supporting multiple losses.
            The order of the loss ``types`` must match the order of the model
            outputs, and ``coef`` holds the weight of each loss, so the number of
            ``coef`` entries must equal the number of model outputs. If the same
            loss is used for every output there may be only one loss type;
            otherwise the number of loss types must equal the number of ``coef``
            entries.
        model: Model config, including type/backbone and model-dependent arguments.
            For model types, please refer to paddleseg.models.
            For backbone types, please refer to paddleseg.models.backbones.

    Args:
        path (str): Path of the config file; only yaml/yml is supported.
        learning_rate (float, optional): When given, overrides
            ``lr_scheduler.learning_rate``.
        batch_size (int, optional): When given, overrides ``batch_size``.
        iters (int, optional): When given, overrides ``iters``.
        opts (list, optional): ``key=value`` items used to update any option;
            ``key`` may be a dotted path such as ``test_config.scales``.
        checker (ConfigChecker, optional): Checker applied to the parsed
            config. Default: the checker built by ``_build_default_checker``.

    Examples:

        from paddleseg.cvlibs import Config, SegBuilder

        # Create a cfg object with yaml file path.
        cfg = Config(yaml_cfg_path)

        # Components are built by SegBuilder, lazily, when a property is used.
        builder = SegBuilder(cfg)
        train_dataset = builder.train_dataset
        model = builder.model
        ...
    """
    def __init__(
            self,
            path: str,
            learning_rate: Optional[float]=None,
            batch_size: Optional[int]=None,
            iters: Optional[int]=None,
            opts: Optional[list]=None,
            checker: Optional[checker.ConfigChecker]=None, ):
        assert os.path.exists(path), \
            'Config path ({}) does not exist'.format(path)
        assert path.endswith(('yml', 'yaml')), \
            'Config file ({}) should be yaml format'.format(path)

        # Parse the yaml file (and its `_base_` files) into a dict.
        self.dic = self._parse_from_yaml(path)
        # Apply the command-line overrides.
        self.dic = self.update_config_dict(
            self.dic,
            learning_rate=learning_rate,
            batch_size=batch_size,
            iters=iters,
            opts=opts)

        # NOTE: inside this method the `checker` parameter shadows the
        # `config_checker` module alias; `_build_default_checker` is a
        # classmethod, so it still sees the module.
        if checker is None:
            checker = self._build_default_checker()
        checker.apply_all_rules(self)

    # Read-only views of the config. The `*_cfg` properties return a SHALLOW
    # copy: reassigning a top-level key of the result does not affect
    # `self.dic`, but nested dicts/lists are still shared with it.
    @property
    def batch_size(self) -> int:
        return self.dic.get('batch_size')

    @property
    def iters(self) -> int:
        return self.dic.get('iters')

    @property
    def to_static_training(self) -> bool:
        return self.dic.get('to_static_training', False)

    @property
    def model_cfg(self) -> Dict:
        return self.dic.get('model', {}).copy()

    @property
    def loss_cfg(self) -> Dict:
        return self.dic.get('loss', {}).copy()

    @property
    def distill_loss_cfg(self) -> Dict:
        return self.dic.get('distill_loss', {}).copy()

    @property
    def lr_scheduler_cfg(self) -> Dict:
        return self.dic.get('lr_scheduler', {}).copy()

    @property
    def optimizer_cfg(self) -> Dict:
        return self.dic.get('optimizer', {}).copy()

    @property
    def train_dataset_cfg(self) -> Dict:
        return self.dic.get('train_dataset', {}).copy()

    @property
    def val_dataset_cfg(self) -> Dict:
        return self.dic.get('val_dataset', {}).copy()

    # TODO merge test_config into val_dataset
    @property
    def test_config(self) -> Dict:
        return self.dic.get('test_config', {}).copy()

    @classmethod
    def update_config_dict(cls, dic: dict, *args, **kwargs) -> dict:
        """Apply command-line overrides; delegates to the module-level `update_config_dict`."""
        return update_config_dict(dic, *args, **kwargs)

    @classmethod
    def _parse_from_yaml(cls, path: str, *args, **kwargs) -> dict:
        """Parse a yaml file; delegates to the module-level `parse_from_yaml`."""
        return parse_from_yaml(path, *args, **kwargs)

    @classmethod
    def _build_default_checker(cls):
        """Build the default ConfigChecker (allow_update=True) with the standard rules."""
        rules = []
        rules.append(checker.DefaultPrimaryRule())
        rules.append(checker.DefaultSyncNumClassesRule())
        rules.append(checker.DefaultSyncImgChannelsRule())
        # Losses
        rules.append(checker.DefaultLossRule('loss'))
        rules.append(checker.DefaultSyncIgnoreIndexRule('loss'))
        # Distillation losses
        rules.append(checker.DefaultLossRule('distill_loss'))
        rules.append(checker.DefaultSyncIgnoreIndexRule('distill_loss'))

        return checker.ConfigChecker(rules, allow_update=True)

    def __str__(self) -> str:
        # Use NoAliasDumper to avoid yml anchor 
        return yaml.dump(self.dic, Dumper=utils.NoAliasDumper)


def parse_from_yaml(path: str):
    """
    Recursively parse a yaml file into a config dict.

    If the file contains the ``_base_`` key (one path or a list of paths,
    relative to the directory of ``path``), every base file is parsed
    recursively and the current file is merged on top of it with
    ``merge_config_dicts``, so values of the current file win. When several
    base files are listed, a base listed earlier takes precedence over one
    listed later.

    Args:
        path (str): Path of the yaml file.

    Returns:
        dict: The parsed and merged config. An empty yaml file yields ``{}``.

    Raises:
        TypeError: If the top level of the yaml file is not a mapping.
    """
    # NOTE(review): FullLoader (not SafeLoader) is used; config files are
    # assumed to be trusted local files. Do not feed untrusted yaml here.
    with open(path, 'r', encoding='utf-8') as file:
        dic = yaml.load(file, Loader=yaml.FullLoader)

    # yaml.load returns None for an empty file; treat it as an empty config
    # instead of failing later with `_BASE_KEY in None`.
    if dic is None:
        dic = {}
    elif not isinstance(dic, dict):
        raise TypeError(
            'The top level of config file ({}) should be a mapping, but got {}.'.format(
                path, type(dic).__name__))

    if _BASE_KEY in dic:
        # The `_base_` entry only drives inheritance; remove it from the result.
        base_files = dic.pop(_BASE_KEY)
        if isinstance(base_files, str):
            base_files = [base_files]
        for bf in base_files:
            # Base paths are relative to the directory of the current file.
            base_path = os.path.join(os.path.dirname(path), bf)
            base_dic = parse_from_yaml(base_path)
            # Everything accumulated so far overrides this base file.
            dic = merge_config_dicts(dic, base_dic)
    return dic


def merge_config_dicts(dic, base_dic):
    """
    Merge `dic` on top of `base_dic` and return the result.

    Values from `dic` win. Nested dicts are merged recursively, key by key,
    unless the nested dict in `dic` contains ``_inherited_: False``; in that
    case it replaces the corresponding base entry entirely. The
    ``_inherited_`` marker is stripped at every merged level, so it is never
    passed on as a component argument. Neither input is modified in place.

    Args:
        dic (dict): The overriding (child) config.
        base_dic (dict): The base config.

    Returns:
        dict: The merged config.
    """
    # Shallow copies; nested dicts are never modified in place below.
    base_dic = base_dic.copy()
    dic = dic.copy()

    # The marker only controls merging: strip it from both sides so it never
    # survives into the final config.
    base_dic.pop(_INHERIT_KEY, None)
    if not dic.pop(_INHERIT_KEY, True):
        # `_inherited_: False`: discard the base entirely.
        base_dic = {}

    for key, val in dic.items():
        if isinstance(val, dict):
            base_val = base_dic.get(key)
            # Recurse even when there is nothing (or a non-dict) to merge with,
            # which both avoids calling `.copy()` on a non-dict base value and
            # strips nested `_inherited_` markers.
            base_dic[key] = merge_config_dicts(
                val, base_val if isinstance(base_val, dict) else {})
        else:
            base_dic[key] = val

    return base_dic


def update_config_dict(dic: dict,
                       learning_rate: Optional[float]=None,
                       batch_size: Optional[int]=None,
                       iters: Optional[int]=None,
                       opts: Optional[list]=None):
    """
    Return a copy of `dic` updated with command-line overrides.

    Args:
        dic (dict): The config dict parsed from yaml. It is not modified.
        learning_rate (float, optional): If not None, overrides
            ``dic['lr_scheduler']['learning_rate']``.
        batch_size (int, optional): If not None, overrides ``dic['batch_size']``.
        iters (int, optional): If not None, overrides ``dic['iters']``.
        opts (list[str], optional): Items of the form ``key=value``. ``key``
            may be a dotted path into nested dicts (e.g. ``test_config.scales``);
            every intermediate key must already exist. ``value`` is parsed with
            ``ast.literal_eval`` when possible (so ``1``, ``0.5`` or
            ``0.75,1.0`` become Python literals) and kept as a string otherwise.
            Only the first ``=`` separates key and value, so values may contain ``=``.

    Returns:
        dict: The updated config.
    """
    import copy

    # TODO: If the items to update are marked as anchors in the yaml file,
    # we should synchronize the references.
    # Deep copy so that updating nested entries never mutates the caller's dict
    # (a shallow copy would share e.g. dic['lr_scheduler'] with the input).
    dic = copy.deepcopy(dic)

    if learning_rate is not None:
        dic['lr_scheduler']['learning_rate'] = learning_rate
    if batch_size is not None:
        dic['batch_size'] = batch_size
    if iters is not None:
        dic['iters'] = iters

    for item in opts or []:
        key, sep, value = item.partition('=')
        assert sep and key, "--opts params should be key=value," \
            " such as `--opts batch_size=1 test_config.scales=0.75,1.0,1.25`, " \
            "but got ({})".format(opts)

        try:
            value = literal_eval(value)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. a plain word or a URL): keep the string.
            pass

        # Walk down the dotted path; the last component is the key to set.
        key_list = key.split('.')
        tmp_dic = dic
        for subkey in key_list[:-1]:
            assert subkey in tmp_dic, "Can not update {}, because it is not in config.".format(key)
            tmp_dic = tmp_dic[subkey]
        tmp_dic[key_list[-1]] = value

    return dic

paddleseg/cvlibs/builder.py

import copy
from typing import Any, Optional
import yaml
import paddle

from paddleseg.cvlibs import manager, Config
from paddleseg.utils import utils, logger
from paddleseg.utils.utils import CachedProperty as cached_property


class Builder(object):
    """
    Base class for building components from config dicts.

    A component config is a dict whose ``type`` key names a class or function
    registered in one of the component managers of ``comp_list``; the other
    keys become keyword arguments of that class/function.

    Args:
        config (Config): A Config object.
        comp_list (list, optional): Component managers searched, in order, for
            the ``type`` name. Default: None
    """
    def __init__(self, config: Config, comp_list: Optional[list]=None):
        super().__init__()
        self.config = config
        self.comp_list = comp_list

    # Example of a component config:
    # {'type': 'MixedLoss', 'losses': [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}], 'coef': [0.4, 0.6]}

    def build_component(self, cfg):
        """
        Create Python object, such as model, loss, dataset, etc.

        Values that are themselves component configs (dicts with a ``type``
        key), either directly or as items of a list, are built recursively
        before being passed to the constructor.

        Args:
            cfg (dict): Component config; must contain ``type``. It is not modified.

        Returns:
            The object returned by the registered class/function.

        Raises:
            RuntimeError: If ``type`` is missing, the component is not
                registered, or creating the object fails.
        """
        # Deep copy so that popping `type` never touches the caller's dict.
        cfg = copy.deepcopy(cfg)
        if 'type' not in cfg:
            raise RuntimeError(
                "It is not possible to create a component object from {}, as 'type' is not specified.".format(cfg)
                )
        class_type = cfg.pop('type')
        com_class = self.load_component_class(class_type)

        params = {}
        for key, val in cfg.items():
            if self.is_meta_type(val):
                params[key] = self.build_component(val)
            elif isinstance(val, list):
                params[key] = [
                    self.build_component(item)
                    if self.is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        try:
            obj = self.build_component_impl(com_class, **params)
        except Exception as e:
            com_name = getattr(com_class, '__name__', '')
            # `from e` keeps the original traceback, which points at the real
            # cause inside the component's constructor.
            raise RuntimeError(
                f"Tried to create a {com_name} object, but the operation has failed. "
                "Please double check the arguments used to create the object.\n"
                f"The error message is: \n{str(e)}") from e

        return obj

    def build_component_impl(self, component_class, *args, **kwargs):
        """Call `component_class` with the given arguments; a hook for subclasses."""
        return component_class(*args, **kwargs)

    def load_component_class(self, class_type):
        """
        Return the class/function registered under `class_type`.

        Raises:
            RuntimeError: If no manager in `comp_list` contains `class_type`,
                including when `comp_list` is None.
        """
        # `comp_list` defaults to None; treat that as "no managers" so the
        # caller gets the intended error instead of a TypeError.
        for com in self.comp_list or []:
            if class_type in com.components_dict:
                return com[class_type]
        raise RuntimeError("The specified component ({}) was not found.".format(class_type))

    @classmethod
    def is_meta_type(cls, obj):
        """Return True if `obj` is a component config, i.e. a dict with a `type` key."""
        # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
        # to make it more pythonic?
        return isinstance(obj, dict) and 'type' in obj

    @classmethod
    def show_msg(cls, name, cfg):
        """Log `cfg`, dumped as yaml under the key `name`."""
        msg = 'Use the following config to build {}\n'.format(name)
        msg += str(yaml.dump({name: cfg}, Dumper=utils.NoAliasDumper))
        # Drop the last character (the trailing newline of the yaml dump).
        logger.info(msg[0:-1])


class SegBuilder(Builder):
    """
    Build the components used for semantic segmentation from a Config.

    Each component is a cached property: it is built on first access, and
    the same object is returned on later accesses.

    When a specific train dataset class (anything other than the generic
    `Dataset`) is configured, the model's `num_classes`/`in_channels` and the
    losses' `ignore_index` are checked against, or filled in from, that
    class's `NUM_CLASSES`/`IMG_CHANNELS`/`IGNORE_INDEX` attributes.

    Args:
        config (Config): A Config object.
        comp_list (list, optional): Component managers searched for `type`
            names. Default: the models, backbones, datasets, transforms,
            losses and optimizers managers.
    """
    def __init__(self, config, comp_list=None):
        if comp_list is None:
            comp_list = [
                manager.MODELS, manager.BACKBONES, manager.DATASETS,
                manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS
            ]
        super().__init__(config, comp_list)

    def _syncs_with_train_dataset(self) -> bool:
        """Return True if a train dataset type other than the generic `Dataset` is configured."""
        # Without a train_dataset `type` there is no dataset class to sync with;
        # `.get` avoids a KeyError for configs that define no train_dataset.
        dataset_type = self.config.train_dataset_cfg.get('type')
        return dataset_type is not None and dataset_type != 'Dataset'

    @cached_property
    def model(self) -> paddle.nn.Layer:
        # Config.model_cfg is only a shallow copy. Deep-copy it so that filling
        # in values below (possibly nested ones) cannot modify self.config.
        model_cfg = copy.deepcopy(self.config.model_cfg)
        assert model_cfg != {}, 'No model specified in the configuration file.'

        if self._syncs_with_train_dataset():
            # Check/sync num_classes between the model config and the dataset class.
            assert hasattr(self.train_dataset_class, 'NUM_CLASSES'), \
                'If train_dataset class is not `Dataset`, it must have `NUM_CLASSES` attr.'
            num_classes = getattr(self.train_dataset_class, 'NUM_CLASSES')
            if 'num_classes' in model_cfg:
                assert model_cfg['num_classes'] == num_classes, \
                    'The num_classes is not consistent for model config ({}) ' \
                    'and train_dataset class ({}) '.format(model_cfg['num_classes'], num_classes)
            else:
                logger.warning(
                    'Add the `num_classes` in train_dataset class to model config. '
                    'We suggest you manually set `num_classes` in model config.'
                )
                model_cfg['num_classes'] = num_classes

            # Check/sync in_channels between the model config and the dataset class.
            assert hasattr(self.train_dataset_class, 'IMG_CHANNELS'), \
                'If train_dataset class is not `Dataset`, it must have `IMG_CHANNELS` attr.'
            in_channels = getattr(self.train_dataset_class, 'IMG_CHANNELS')
            x = utils.get_in_channels(model_cfg)
            if x is not None:
                assert x == in_channels, \
                    'The in_channels in model config ({}) and the img_channels in train_dataset ' \
                    'class ({}) is not consistent'.format(x, in_channels)
            else:
                model_cfg = utils.set_in_channels(model_cfg, in_channels)
                logger.warning(
                    'Add the `in_channels` in train_dataset class to model config. '
                    'We suggest you manually set `in_channels` in model config.'
                )

        self.show_msg('model', model_cfg)
        return self.build_component(model_cfg)

    @cached_property
    def optimizer(self) -> paddle.optimizer.Optimizer:
        opt_cfg = self.config.optimizer_cfg
        assert opt_cfg != {}, 'No optimizer specified in the configuration file.'
        # For compatibility with lower-case type names.
        if opt_cfg['type'] == 'adam':
            opt_cfg['type'] = 'Adam'
        if opt_cfg['type'] == 'sgd':
            opt_cfg['type'] = 'SGD'
        if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg:
            opt_cfg['type'] = 'Momentum'
            logger.info('If the type is SGD and momentum in optimizer config, '
                        'the type is changed to Momentum.')
        self.show_msg('optimizer', opt_cfg)
        # The registered optimizer component is called again with the model and
        # the lr scheduler to produce the final optimizer.
        opt = self.build_component(opt_cfg)
        opt = opt(self.model, self.lr_scheduler)
        return opt

    @cached_property
    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
        lr_cfg = self.config.lr_scheduler_cfg
        assert lr_cfg != {}, 'No lr_scheduler specified in the configuration file.'

        # Optional linear warmup; its keys are not arguments of the scheduler itself.
        use_warmup = False
        if 'warmup_iters' in lr_cfg:
            use_warmup = True
            warmup_iters = lr_cfg.pop('warmup_iters')
            assert 'warmup_start_lr' in lr_cfg, \
                "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
            warmup_start_lr = lr_cfg.pop('warmup_start_lr')
            end_lr = lr_cfg['learning_rate']

        lr_type = lr_cfg.pop('type')
        if lr_type == 'PolynomialDecay':
            # Decay over the iterations left after warmup (at least 1).
            iters = self.config.iters - warmup_iters if use_warmup else self.config.iters
            iters = max(iters, 1)
            lr_cfg.setdefault('decay_steps', iters)

        try:
            lr_sche = getattr(paddle.optimizer.lr, lr_type)(**lr_cfg)
        except Exception as e:
            raise RuntimeError(
                "Create {} has failed. Please check lr_scheduler in config. "
                "The error message: {}".format(lr_type, e))

        if use_warmup:
            lr_sche = paddle.optimizer.lr.LinearWarmup(
                learning_rate=lr_sche,
                warmup_steps=warmup_iters,
                start_lr=warmup_start_lr,
                end_lr=end_lr)

        return lr_sche

    @cached_property
    def loss(self) -> dict:
        loss_cfg = self.config.loss_cfg
        assert loss_cfg != {}, 'No loss specified in the configuration file.'
        return self._build_loss('loss', loss_cfg)

    @cached_property
    def distill_loss(self) -> dict:
        loss_cfg = self.config.distill_loss_cfg
        assert loss_cfg != {}, 'No distill_loss specified in the configuration file.'
        return self._build_loss('distill_loss', loss_cfg)

    def _build_loss(self, loss_name, loss_cfg: dict):
        """
        Build ``{'coef': [...], 'types': [loss objects]}`` from `loss_cfg`.

        When a specific train dataset class is configured, the `ignore_index`
        of every loss (including the sub-losses of a MixedLoss) is checked
        against, or filled in from, the class's IGNORE_INDEX.
        """
        # The caller passes a shallow copy of the config section; deep-copy it
        # so that filling in `ignore_index` does not modify self.config.
        loss_cfg = copy.deepcopy(loss_cfg)

        def _check_helper(loss_cfg, ignore_index):
            if 'ignore_index' not in loss_cfg:
                loss_cfg['ignore_index'] = ignore_index
                logger.warning('Add the `ignore_index` in train_dataset class to {} config. ' \
                    'We suggest you manually set `ignore_index` in {} config.'.format(loss_name, loss_name)
                )
            else:
                assert loss_cfg['ignore_index'] == ignore_index, \
                    'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, '\
                    'train_dataset ignore_index = {}'.format(loss_cfg['ignore_index'], ignore_index)

        # Check/sync ignore_index between the loss configs and the dataset class.
        if self._syncs_with_train_dataset():
            assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \
                'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.'
            ignore_index = getattr(self.train_dataset_class, 'IGNORE_INDEX')
            for loss_cfg_i in loss_cfg['types']:
                if loss_cfg_i['type'] == 'MixedLoss':
                    # e.g. [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}]
                    for loss_cfg_j in loss_cfg_i['losses']:
                        _check_helper(loss_cfg_j, ignore_index)
                else:
                    _check_helper(loss_cfg_i, ignore_index)

        self.show_msg(loss_name, loss_cfg)
        loss_dict = {'coef': loss_cfg['coef'], "types": []}
        for item in loss_cfg['types']:
            loss_dict['types'].append(self.build_component(item))

        return loss_dict

    @cached_property
    def train_dataset(self) -> paddle.io.Dataset:
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, 'No train_dataset specified in the configuration file.'
        self.show_msg('train_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        assert len(dataset) != 0, \
            'The number of samples in train_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def val_dataset(self) -> paddle.io.Dataset:
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        self.show_msg('val_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        if len(dataset) == 0:
            logger.warning('The number of samples in val_dataset is 0. Please ensure this is the desired behavior.')
        return dataset

    @cached_property
    def train_dataset_class(self) -> Any:
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, 'No train_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_dataset_class(self) -> Any:
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_transforms(self) -> list:
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        transforms = []
        for item in dataset_cfg.get('transforms', []):
            transforms.append(self.build_component(item))
        return transforms

paddleseg/cvlibs/manager.py

特别注意:这里把具体实现的类(如 class Cityscapes(Dataset) 等)称为组件(component);
组件管理器(ComponentManager)则是对应的模型(models)管理器、数据集(datasets)管理器等,负责按名称登记和查找这些组件。

import inspect
from collections.abc import Sequence

import warnings


class ComponentManager:
    """
    Registry that maps component names to classes or functions.

    Components are registered under their ``__name__`` and can later be
    looked up by that name, e.g. ``manager['ResNet']`` or ``'ResNet' in manager``.

    Args:
        name (str, optional): The name of the manager, used in ``repr``. Default: None.

    Returns:
        A callable object of ComponentManager.

    Examples 1:

        from paddleseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        class AlexNet: ...
        class ResNet: ...

        model_manager.add_component(AlexNet)
        model_manager.add_component(ResNet)

        # Or pass a sequence alternatively:
        model_manager.add_component([AlexNet, ResNet])
        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}

    Examples 2:

        # Or an easier way: use it as a decorator right above the class declaration.
        from paddleseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        @model_manager.add_component
        class AlexNet: ...

        @model_manager.add_component
        class ResNet: ...

        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
    """
    def __init__(self, name=None):
        self._components_dict = dict()
        self._name = name

    def __len__(self):
        return len(self._components_dict)

    def __contains__(self, item):
        """Return True if a component named `item` is registered."""
        # Without this, `in` falls back to __getitem__(0), which raises KeyError.
        return item in self._components_dict

    def __repr__(self):
        name_str = self._name if self._name else self.__class__.__name__
        return "{}:{}".format(name_str, list(self._components_dict.keys()))

    def __getitem__(self, item):
        if item not in self._components_dict:
            raise KeyError("{} does not exist in available {}".format(item, self))
        return self._components_dict[item]

    @property
    def components_dict(self):
        return self._components_dict

    @property
    def name(self):
        return self._name

    def _add_single_component(self, component):
        """
        Register a single component under its ``__name__``.

        A component with the same name that is already registered is replaced,
        with a warning.

        Args:
            component (function|class): A new component.

        Raises:
            TypeError: When `component` is neither class nor function.
        """
        # Only classes and plain functions are supported.
        if not (inspect.isclass(component) or inspect.isfunction(component)):
            raise TypeError("Expect class/function type, but received {}".format(type(component)))

        component_name = component.__name__

        if component_name in self._components_dict:
            warnings.warn("{} exists already! It is now updated to {} !!!".format(component_name, component))
        self._components_dict[component_name] = component

    def add_component(self, components):
        """
        Register one component or a sequence of components.

        Because the input is returned unchanged, this method can be used as a
        class/function decorator.

        Args:
            components (function|class|list|tuple): A component or a sequence of components.

        Returns:
            components (function|class|list|tuple): Same with input components.
        """
        if isinstance(components, Sequence):
            for component in components:
                self._add_single_component(component)
        else:
            self._add_single_component(components)

        return components


# Global component managers, one per component category. Components can be
# registered with `add_component` (e.g. as the decorator
# `@manager.MODELS.add_component`), and SegBuilder searches these managers
# for the `type` named in a config.
MODELS = ComponentManager(name="models")  # segmentation models
BACKBONES = ComponentManager(name="backbones")  # backbone networks
DATASETS = ComponentManager(name="datasets")  # datasets
TRANSFORMS = ComponentManager(name="transforms")  # data augmentation / transforms
LOSSES = ComponentManager(name="losses")  # loss functions
OPTIMIZERS = ComponentManager(name="optimizers")  # optimizers

你可能感兴趣的:(PaddleSeg使用及其解析,深度学习,人工智能,神经网络,python)