import argparse
import random
import numpy as np
import cv2
import paddle
from paddleseg.cvlibs import Config, SegBuilder
from paddleseg.utils import get_sys_env, logger, utils
from paddleseg.core import train
def parse_args(argv=None):
    """Build the command-line parser for model training and parse the options.

    Args:
        argv (list[str], optional): Argument strings to parse. When None
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace: The parsed command-line options.
    """
    parser = argparse.ArgumentParser(description='Model training')

    # Config file describing the training run and the model.
    parser.add_argument(
        "--config",
        help="The config file.",
        type=str,
        default=None)
    # Device used for training.
    parser.add_argument(
        '--device',
        help='Set the device place for training model.',
        type=str,
        default='gpu',
        choices=['cpu', 'gpu', 'xpu', 'npu', 'mlu'])
    # Directory where model snapshots are written.
    parser.add_argument(
        '--save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    # Number of worker processes used by the data loader.
    parser.add_argument(
        '--num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)
    # Whether to evaluate while training.
    parser.add_argument(
        '--do_eval',
        help='Eval while training',
        action='store_true')
    # Whether to record training data with VisualDL.
    parser.add_argument(
        '--use_vdl',
        help='Whether to record the data to VisualDL during training',
        action='store_true')
    # Whether to apply EMA to the model during training.
    parser.add_argument(
        '--use_ema',
        help='Whether to ema the model in training.',
        action='store_true')
    # Checkpoint path to resume training from.
    parser.add_argument(
        '--resume_model',
        help='The path of resume model',
        type=str,
        default=None)
    # Number of training iterations (overrides the config file when set).
    parser.add_argument(
        '--iters',
        help='iters for training',
        type=int,
        default=None)
    # Mini-batch size (overrides the config file when set).
    parser.add_argument(
        '--batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=None)
    # Initial learning rate (overrides the config file when set).
    parser.add_argument(
        '--learning_rate',
        help='Learning rate',
        type=float,
        default=None)
    # Interval, in iterations, between two model snapshots.
    parser.add_argument(
        '--save_interval',
        help='How many iters to save a model snapshot once during training.',
        type=int,
        default=1000)
    # Interval, in iterations, between two log messages.
    parser.add_argument(
        '--log_iters',
        help='Display logging information at every log_iters',
        type=int,
        default=10)
    # Maximum number of checkpoints kept on disk.
    parser.add_argument(
        '--keep_checkpoint_max',
        help='Maximum number of checkpoints to save',
        type=int,
        default=20)
    # Random seed.
    parser.add_argument(
        '--seed',
        help='Set the random seed during training.',
        type=int,
        default=None)
    # Mixed-precision (fp16) or regular (fp32) training.
    parser.add_argument(
        "--precision",
        type=str,
        default="fp32",
        choices=["fp32", "fp16"],
        help="Use AMP (Auto mixed precision) if precision='fp16'. If precision='fp32', the training is normal."
    )
    # Auto mixed precision level.
    parser.add_argument(
        "--amp_level",
        default="O1",
        type=str,
        choices=["O1", "O2"],
        help="Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, "
        "the input data type of each operator will be casted by white_list and black_list; "
        "O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, "
        "except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp).")
    # Options of the training profiler.
    parser.add_argument(
        '--profiler_options',
        type=str,
        default=None,
        # The two literals are concatenated, so the first one must end the
        # sentence itself ("enabled. ") to keep the help text readable.
        help='The option of train profiler. If profiler_options is not None, the train profiler is enabled. '
        'Refer to the paddleseg/utils/train_profiler.py for details.'
    )
    # Layout of the input data: "NCHW" or "NHWC".
    parser.add_argument(
        '--data_format',
        help='Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW".',
        type=str,
        default='NCHW')
    # How many times the dataset samples are repeated in each epoch.
    parser.add_argument(
        '--repeats',
        type=int,
        default=1,
        help="Repeat the samples in the dataset for `repeats` times in each epoch."
    )
    # Free-form key=value overrides; nargs='+' requires at least one value
    # after the flag.
    parser.add_argument(
        '--opts',
        help='Update the key-value pairs of all options.',
        nargs='+')

    return parser.parse_args(argv)
def main(args):
    """Run a training job described by the parsed command-line options.

    Builds the config and the component builder, applies the global
    environment settings (seed, device, cv2 threads), optionally switches the
    model to the NHWC layout, then hands everything to ``train``.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        AssertionError: If ``args.config`` is None.
        ValueError: If ``--data_format NHWC`` is requested for a model whose
            type is not ``DeepLabV3P``.
    """
    # A config file is mandatory.
    assert args.config is not None, 'No configuration file specified, please set --config'

    # Command-line values override the ones read from the config file.
    cfg = Config(
        args.config,
        learning_rate=args.learning_rate,
        iters=args.iters,
        batch_size=args.batch_size,
        opts=args.opts)
    builder = SegBuilder(cfg)

    utils.show_env_info()
    utils.show_cfg_info(cfg)
    utils.set_seed(args.seed)
    utils.set_device(args.device)
    utils.set_cv2_num_threads(args.num_workers)

    # The NHWC layout is only supported for the DeepLabV3P model.
    if args.data_format == 'NHWC':
        model_section = cfg.dic['model']
        if model_section['type'] != 'DeepLabV3P':
            raise ValueError('The "NHWC" data format only support the DeepLabV3P model!')
        # Every config entry that carries a data_format has to be switched.
        model_section['data_format'] = args.data_format
        model_section['backbone']['data_format'] = args.data_format
        for loss_item in cfg.dic['loss']['types']:
            loss_item['data_format'] = args.data_format

    network = utils.convert_sync_batchnorm(builder.model, args.device)

    # Training set, optionally repeated several times.
    train_set = builder.train_dataset
    if args.repeats > 1:
        train_set.file_list *= args.repeats

    # The validation set is only built when evaluation is requested.
    val_set = builder.val_dataset if args.do_eval else None

    optim = builder.optimizer
    criterion = builder.loss

    train(
        model=network,
        train_dataset=train_set,
        val_dataset=val_set,
        optimizer=optim,
        save_dir=args.save_dir,
        iters=cfg.iters,
        batch_size=cfg.batch_size,
        resume_model=args.resume_model,
        save_interval=args.save_interval,
        log_iters=args.log_iters,
        num_workers=args.num_workers,
        use_vdl=args.use_vdl,
        use_ema=args.use_ema,
        losses=criterion,
        keep_checkpoint_max=args.keep_checkpoint_max,
        test_config=cfg.test_config,
        precision=args.precision,
        amp_level=args.amp_level,
        profiler_options=args.profiler_options,
        to_static_training=cfg.to_static_training)
if __name__ == "__main__":
    # Script entry point: parse the command line, then start training.
    args = parse_args()
    main(args)
import six
import codecs
import os
from ast import literal_eval
from typing import Any, Dict, Optional
import yaml
import paddle
from paddleseg.cvlibs import config_checker as checker
from paddleseg.cvlibs import manager
from paddleseg.utils import logger, utils
# Key of a config dict; when it maps to a falsy value, merge_config_dicts
# returns that dict on its own (without this key) instead of merging it into
# the base dict.
_INHERIT_KEY = '_inherited_'
# Key of a YAML config that names the base config file(s) it inherits from
# (a single path or a list of paths, resolved relative to the config file).
_BASE_KEY = '_base_'
class Config(object):
    """
    Parse a training configuration file. Only yaml/yml files are supported.

    Hyper-parameters that can appear in the configuration file:
        batch_size: The number of samples per gpu.
        iters: The total number of training iterations.
        train_dataset: Config of the training data, including type/data_root/transforms/mode.
            For data type, please refer to paddleseg.datasets.
            For specific transforms, please refer to paddleseg.transforms.transforms.
        val_dataset: Config of the validation data, including type/data_root/transforms/mode.
        optimizer: Config of the optimizer, please refer to paddleseg.optimizers.
        learning_rate: Config of the learning rate. When a decay is configured, the value is
            the initial learning rate; only poly decay is supported for now. The decay power
            and the end_lr have to be tuned experimentally.
        loss: Config of the loss function; several losses can be combined.
            The order of the loss types must match the outputs of the segmentation model, and
            coef holds the weight of each loss, so the number of coef entries must equal the
            number of model outputs. A single loss type may be given when every output uses the
            same loss; otherwise the number of loss types must equal the number of coef entries.
        model: Config of the model, including type/backbone and model-dependent arguments.
            For model type, please refer to paddleseg.models.
            For backbone, please refer to paddleseg.models.backbones.

    Args:
        path (str): Path of the config file; only the yaml format is supported.
        learning_rate (float, optional): Overrides the learning rate of the config file.
        batch_size (int, optional): Overrides the batch size of the config file.
        iters (int, optional): Overrides the iteration count of the config file.
        opts (list, optional): ``key=value`` strings used to update arbitrary options.
        checker (ConfigChecker, optional): Checker applied to the config once it is loaded.
            When None, the one returned by ``_build_default_checker`` is used.

    Examples:
        from paddleseg.cvlibs.config import Config

        # Create a cfg object with yaml file path.
        cfg = Config(yaml_cfg_path)

        # Each section is exposed as a read-only property returning a copy.
        train_dataset_cfg = cfg.train_dataset_cfg
        model_cfg = cfg.model_cfg
        ...
    """

    def __init__(
            self,
            path: str,
            learning_rate: Optional[float]=None,
            batch_size: Optional[int]=None,
            iters: Optional[int]=None,
            opts: Optional[list]=None,
            checker: Optional[checker.ConfigChecker]=None, ):
        assert os.path.exists(path), \
            'Config path ({}) does not exist'.format(path)
        assert path.endswith(('yml', 'yaml')), \
            'Config file ({}) should be yaml format'.format(path)

        # Turn the yaml file into a dict, then apply the overrides passed in.
        parsed = self._parse_from_yaml(path)
        self.dic = self.update_config_dict(
            parsed,
            learning_rate=learning_rate,
            batch_size=batch_size,
            iters=iters,
            opts=opts)

        # NOTE: inside this method `checker` is the argument, not the module.
        active_checker = self._build_default_checker() if checker is None else checker
        active_checker.apply_all_rules(self)

    def _copy_section(self, name: str) -> Dict:
        """Return a shallow copy of the top-level section `name` ({} if absent)."""
        return self.dic.get(name, {}).copy()

    # The values below are exposed through @property, which gives each of them
    # a read-only attribute of the same name.
    @property
    def batch_size(self) -> int:
        return self.dic.get('batch_size')

    @property
    def iters(self) -> int:
        return self.dic.get('iters')

    @property
    def to_static_training(self) -> bool:
        return self.dic.get('to_static_training', False)

    @property
    def model_cfg(self) -> Dict:
        return self._copy_section('model')

    @property
    def loss_cfg(self) -> Dict:
        return self._copy_section('loss')

    @property
    def distill_loss_cfg(self) -> Dict:
        return self._copy_section('distill_loss')

    @property
    def lr_scheduler_cfg(self) -> Dict:
        return self._copy_section('lr_scheduler')

    @property
    def optimizer_cfg(self) -> Dict:
        return self._copy_section('optimizer')

    @property
    def train_dataset_cfg(self) -> Dict:
        return self._copy_section('train_dataset')

    @property
    def val_dataset_cfg(self) -> Dict:
        return self._copy_section('val_dataset')

    # TODO merge test_config into val_dataset
    @property
    def test_config(self) -> Dict:
        return self._copy_section('test_config')

    # Class methods can be called as Config.method() as well as on an instance.
    @classmethod
    def update_config_dict(cls, dic: dict, *args, **kwargs) -> dict:
        """Delegate to the module-level ``update_config_dict``."""
        return update_config_dict(dic, *args, **kwargs)

    @classmethod
    def _parse_from_yaml(cls, path: str, *args, **kwargs) -> dict:
        """Delegate to the module-level ``parse_from_yaml``."""
        return parse_from_yaml(path, *args, **kwargs)

    @classmethod
    def _build_default_checker(cls):
        """Build the checker that is used when the caller supplies none."""
        rules = [
            checker.DefaultPrimaryRule(),
            checker.DefaultSyncNumClassesRule(),
            checker.DefaultSyncImgChannelsRule(),
            # Losses
            checker.DefaultLossRule('loss'),
            checker.DefaultSyncIgnoreIndexRule('loss'),
            # Distillation losses
            checker.DefaultLossRule('distill_loss'),
            checker.DefaultSyncIgnoreIndexRule('distill_loss'),
        ]
        return checker.ConfigChecker(rules, allow_update=True)

    def __str__(self) -> str:
        # Use NoAliasDumper to avoid yml anchor
        return yaml.dump(self.dic, Dumper=utils.NoAliasDumper)
def parse_from_yaml(path: str):
    """Recursively parse a yaml file and build the config dict.

    If the file contains the ``_base_`` key, every base file it names (one
    path or a list of paths, relative to the directory of `path`) is parsed
    the same way and the current file is merged on top of it with
    ``merge_config_dicts``.

    Args:
        path (str): Path of the yaml file.

    Returns:
        dict: The parsed config with all base files merged in. An empty yaml
            file yields an empty dict.
    """
    with codecs.open(path, 'r', 'utf-8') as file:
        # NOTE(review): FullLoader is kept for compatibility; PyYAML recommends
        # yaml.safe_load for input that is not trusted, so only load config
        # files from trusted sources.
        dic = yaml.load(file, Loader=yaml.FullLoader)

    # An empty yaml document is loaded as None; treat it as an empty config so
    # the membership test below (and the merge in the caller) keep working.
    if dic is None:
        dic = {}

    if _BASE_KEY in dic:
        # pop() removes the key, so `_base_` does not leak into the config.
        base_files = dic.pop(_BASE_KEY)
        if isinstance(base_files, str):
            base_files = [base_files]
        for bf in base_files:
            # Base paths are resolved relative to the current file.
            base_path = os.path.join(os.path.dirname(path), bf)
            base_dic = parse_from_yaml(base_path)
            # Values accumulated so far take precedence over the base file.
            dic = merge_config_dicts(dic, base_dic)

    return dic
def merge_config_dicts(dic, base_dic):
    """Merge `dic` into `base_dic` and return the result.

    Neither argument is modified at the top level: both are shallow-copied
    first. Values of `dic` override those of `base_dic`; when both sides hold
    a dict under the same key, the two dicts are merged recursively.

    Args:
        dic (dict): The overriding config.
        base_dic (dict): The config being inherited from.

    Returns:
        dict: The merged config. If `dic` maps ``_inherited_`` to a falsy
            value, `dic` alone (without that key) is returned and `base_dic`
            is ignored.
    """
    base_dic = base_dic.copy()
    dic = dic.copy()

    # `_inherited_: False` opts out of inheritance for this dict.
    if not dic.get(_INHERIT_KEY, True):
        dic.pop(_INHERIT_KEY)
        return dic

    for key, val in dic.items():
        # Recurse only when BOTH sides are dicts. If the base value is a
        # scalar, a list or None there is nothing to merge into, so the
        # override simply replaces it.
        if isinstance(val, dict) and isinstance(base_dic.get(key), dict):
            base_dic[key] = merge_config_dicts(val, base_dic[key])
        else:
            base_dic[key] = val

    return base_dic
def update_config_dict(dic: dict,
                       learning_rate: Optional[float]=None,
                       batch_size: Optional[int]=None,
                       iters: Optional[int]=None,
                       opts: Optional[list]=None):
    """Return a copy of `dic` with the given overrides applied.

    Args:
        dic (dict): The config dict. It is not modified.
        learning_rate (float, optional): Stored in
            ``dic['lr_scheduler']['learning_rate']`` when truthy.
        batch_size (int, optional): Stored in ``dic['batch_size']`` when truthy.
        iters (int, optional): Stored in ``dic['iters']`` when truthy.
        opts (list, optional): ``key=value`` strings. A dotted key addresses a
            nested entry (``a.b.c=1``). The value is evaluated with
            ``ast.literal_eval`` and kept as a plain string when it is not a
            Python literal.

    Returns:
        dict: The updated copy.

    Raises:
        AssertionError: If an item of `opts` is not of the form ``key=value``
            or an intermediate key of a dotted path is missing.
        KeyError: If `learning_rate` is given but `dic` has no
            ``lr_scheduler`` section.
    """
    # Function-scope import keeps this block self-contained.
    import copy

    # TODO: If the items to update are marked as anchors in the yaml file,
    # we should synchronize the references.
    # A deep copy is required: the nested writes below would otherwise modify
    # the caller's dict. Objects shared inside `dic` stay shared in the copy.
    dic = copy.deepcopy(dic)

    if learning_rate:
        dic['lr_scheduler']['learning_rate'] = learning_rate
    if batch_size:
        dic['batch_size'] = batch_size
    if iters:
        dic['iters'] = iters

    if opts is not None:
        for item in opts:
            assert ('=' in item) and (len(item.split('=')) == 2), "--opts params should be key=value," \
                " such as `--opts batch_size=1 test_config.scales=0.75,1.0,1.25`, " \
                "but got ({})".format(opts)

            key, value = item.split('=')
            try:
                value = literal_eval(value)
            except (ValueError, TypeError, SyntaxError):
                # Not a Python literal: keep the raw string.
                pass

            # Walk down to the dict that holds the last key of the path.
            key_list = key.split('.')
            tmp_dic = dic
            for subkey in key_list[:-1]:
                assert subkey in tmp_dic, "Can not update {}, because it is not in config.".format(key)
                tmp_dic = tmp_dic[subkey]
            tmp_dic[key_list[-1]] = value

    return dic
import copy
from typing import Any, Optional
import yaml
import paddle
from paddleseg.cvlibs import manager, Config
from paddleseg.utils import utils, logger
from paddleseg.utils.utils import CachedProperty as cached_property
class Builder(object):
    """
    Base class for building components from config dicts.

    Args:
        config (Config): A Config object.
        comp_list (list, optional): The component managers that are searched
            when a component class is looked up by name. Default: None
    """

    # The annotation is quoted so that it is not evaluated when the class is
    # defined.
    def __init__(self, config: "Config", comp_list: Optional[list]=None):
        super().__init__()
        self.config = config
        self.comp_list = comp_list

    def build_component(self, cfg):
        """
        Create Python object, such as model, loss, dataset, etc.

        `cfg` is a dict with a 'type' entry naming the component class; the
        remaining entries become keyword arguments of that class. Nested
        dicts that have a 'type' entry (directly, or as items of a list) are
        built recursively first, e.g.
        {'type': 'MixedLoss', 'losses': [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}], 'coef': [0.4, 0.6]}

        Args:
            cfg (dict): The component description. It is deep-copied, so the
                caller's dict is left untouched.

        Returns:
            The object returned by ``build_component_impl``.

        Raises:
            RuntimeError: If 'type' is missing, the component is not found,
                or creating the object fails (the original exception is
                attached as ``__cause__``).
        """
        cfg = copy.deepcopy(cfg)
        if 'type' not in cfg:
            raise RuntimeError(
                "It is not possible to create a component object from {}, as 'type' is not specified.".format(cfg)
            )

        class_type = cfg.pop('type')
        com_class = self.load_component_class(class_type)

        # Keyword arguments for the component, with nested components built.
        params = {}
        for key, val in cfg.items():
            if self.is_meta_type(val):
                params[key] = self.build_component(val)
            elif isinstance(val, list):
                params[key] = [
                    self.build_component(item)
                    if self.is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        try:
            obj = self.build_component_impl(com_class, **params)
        except Exception as e:
            com_name = getattr(com_class, '__name__', '')
            # `from e` keeps the original exception as the explicit cause.
            raise RuntimeError(
                f"Tried to create a {com_name} object, but the operation has failed. "
                "Please double check the arguments used to create the object.\n"
                f"The error message is: \n{str(e)}") from e

        return obj

    def build_component_impl(self, component_class, *args, **kwargs):
        """Instantiate `component_class`; subclasses may override this hook."""
        return component_class(*args, **kwargs)

    def load_component_class(self, class_type):
        """Return the component registered under `class_type`.

        The managers in ``self.comp_list`` are searched in order.

        Raises:
            RuntimeError: If no manager holds `class_type` (this includes the
                case where ``comp_list`` was left at None).
        """
        for com in self.comp_list or []:
            if class_type in com.components_dict:
                return com[class_type]
        raise RuntimeError("The specified component ({}) was not found.".format(class_type))

    @classmethod
    def is_meta_type(cls, obj):
        """Return True if `obj` is a dict that has a 'type' key."""
        # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
        # to make it more pythonic?
        return isinstance(obj, dict) and 'type' in obj

    @classmethod
    def show_msg(cls, name, cfg):
        """Log the config `cfg` that is used to build the component `name`."""
        msg = 'Use the following config to build {}\n'.format(name)
        msg += str(yaml.dump({name: cfg}, Dumper=utils.NoAliasDumper))
        # Drop the last character (the newline that ends the yaml dump).
        logger.info(msg[0:-1])
class SegBuilder(Builder):
    """
    Builder for the components used in semantic segmentation.

    Each component is exposed through a `cached_property` (an alias of
    ``utils.CachedProperty``); presumably it is built on first access and the
    result is reused afterwards - confirm against CachedProperty.

    Args:
        config (Config): A Config object.
        comp_list (list, optional): The component managers to search. When
            None, the models, backbones, datasets, transforms, losses and
            optimizers managers are used.
    """

    def __init__(self, config, comp_list=None):
        if comp_list is None:
            comp_list = [
                manager.MODELS, manager.BACKBONES, manager.DATASETS,
                manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS
            ]
        super().__init__(config, comp_list)

    @cached_property
    def model(self) -> paddle.nn.Layer:
        """Build the model after syncing num_classes/in_channels with the dataset class."""
        model_cfg = self.config.model_cfg
        assert model_cfg != {}, 'No model specified in the configuration file.'

        if self.config.train_dataset_cfg['type'] != 'Dataset':
            # Check and synchronize num_classes between the model config and
            # the dataset class.
            assert hasattr(self.train_dataset_class, 'NUM_CLASSES'), \
                'If train_dataset class is not `Dataset`, it must have `NUM_CLASSES` attr.'
            num_classes = getattr(self.train_dataset_class, 'NUM_CLASSES')
            if 'num_classes' in model_cfg:
                assert model_cfg['num_classes'] == num_classes, \
                    'The num_classes is not consistent for model config ({}) ' \
                    'and train_dataset class ({}) '.format(model_cfg['num_classes'], num_classes)
            else:
                logger.warning(
                    'Add the `num_classes` in train_dataset class to model config. '
                    'We suggest you manually set `num_classes` in model config.'
                )
                model_cfg['num_classes'] = num_classes

            # Check and synchronize in_channels between the model config and
            # the dataset class.
            assert hasattr(self.train_dataset_class, 'IMG_CHANNELS'), \
                'If train_dataset class is not `Dataset`, it must have `IMG_CHANNELS` attr.'
            in_channels = getattr(self.train_dataset_class, 'IMG_CHANNELS')
            x = utils.get_in_channels(model_cfg)
            if x is not None:
                assert x == in_channels, \
                    'The in_channels in model config ({}) and the img_channels in train_dataset ' \
                    'class ({}) is not consistent'.format(x, in_channels)
            else:
                model_cfg = utils.set_in_channels(model_cfg, in_channels)
                logger.warning(
                    'Add the `in_channels` in train_dataset class to model config. '
                    'We suggest you manually set `in_channels` in model config.'
                )

        self.show_msg('model', model_cfg)
        return self.build_component(model_cfg)

    @cached_property
    def optimizer(self) -> paddle.optimizer.Optimizer:
        """Build the optimizer for ``self.model`` driven by ``self.lr_scheduler``."""
        opt_cfg = self.config.optimizer_cfg
        assert opt_cfg != {}, 'No optimizer specified in the configuration file.'

        # For compatibility
        if opt_cfg['type'] == 'adam':
            opt_cfg['type'] = 'Adam'
        if opt_cfg['type'] == 'sgd':
            opt_cfg['type'] = 'SGD'
        if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg:
            opt_cfg['type'] = 'Momentum'
            logger.info('If the type is SGD and momentum in optimizer config, '
                        'the type is changed to Momentum.')

        self.show_msg('optimizer', opt_cfg)
        # The built component is itself called with the model and the
        # scheduler to obtain the optimizer that is returned.
        opt = self.build_component(opt_cfg)
        opt = opt(self.model, self.lr_scheduler)
        return opt

    @cached_property
    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
        """Build the learning-rate scheduler, wrapped in LinearWarmup when configured.

        Raises:
            RuntimeError: If creating the scheduler fails (the original
                exception is attached as ``__cause__``).
        """
        lr_cfg = self.config.lr_scheduler_cfg
        assert lr_cfg != {}, 'No lr_scheduler specified in the configuration file.'

        # Warm-up keys are removed from lr_cfg because the remaining entries
        # are passed as keyword arguments to the scheduler class.
        use_warmup = False
        if 'warmup_iters' in lr_cfg:
            use_warmup = True
            warmup_iters = lr_cfg.pop('warmup_iters')
            assert 'warmup_start_lr' in lr_cfg, \
                "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
            warmup_start_lr = lr_cfg.pop('warmup_start_lr')
            end_lr = lr_cfg['learning_rate']

        lr_type = lr_cfg.pop('type')
        if lr_type == 'PolynomialDecay':
            # By default decay over the iterations left after warm-up (>= 1).
            iters = self.config.iters - warmup_iters if use_warmup else self.config.iters
            iters = max(iters, 1)
            lr_cfg.setdefault('decay_steps', iters)

        try:
            lr_sche = getattr(paddle.optimizer.lr, lr_type)(**lr_cfg)
        except Exception as e:
            raise RuntimeError(
                "Create {} has failed. Please check lr_scheduler in config. "
                "The error message: {}".format(lr_type, e)) from e

        if use_warmup:
            lr_sche = paddle.optimizer.lr.LinearWarmup(
                learning_rate=lr_sche,
                warmup_steps=warmup_iters,
                start_lr=warmup_start_lr,
                end_lr=end_lr)

        return lr_sche

    @cached_property
    def loss(self) -> dict:
        """Build the losses of the ``loss`` config section."""
        loss_cfg = self.config.loss_cfg
        assert loss_cfg != {}, 'No loss specified in the configuration file.'
        return self._build_loss('loss', loss_cfg)

    @cached_property
    def distill_loss(self) -> dict:
        """Build the losses of the ``distill_loss`` config section."""
        loss_cfg = self.config.distill_loss_cfg
        assert loss_cfg != {}, 'No distill_loss specified in the configuration file.'
        return self._build_loss('distill_loss', loss_cfg)

    def _build_loss(self, loss_name, loss_cfg: dict):
        """Build every loss listed in ``loss_cfg['types']``.

        Args:
            loss_name (str): Name of the config section, used in messages.
            loss_cfg (dict): The section, with the keys 'types' and 'coef'.

        Returns:
            dict: ``{'coef': loss_cfg['coef'], 'types': [built loss objects]}``.
        """

        def _check_helper(loss_cfg, ignore_index):
            # Fill in a missing ignore_index, or check that the configured one
            # matches the dataset class.
            if 'ignore_index' not in loss_cfg:
                loss_cfg['ignore_index'] = ignore_index
                logger.warning('Add the `ignore_index` in train_dataset class to {} config. '
                               'We suggest you manually set `ignore_index` in {} config.'.format(loss_name, loss_name)
                               )
            else:
                assert loss_cfg['ignore_index'] == ignore_index, \
                    'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, '\
                    'train_dataset ignore_index = {}'.format(loss_cfg['ignore_index'], ignore_index)

        # Check and synchronize ignore_index between the loss config and the
        # dataset class.
        if self.config.train_dataset_cfg['type'] != 'Dataset':
            assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \
                'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.'
            ignore_index = getattr(self.train_dataset_class, 'IGNORE_INDEX')
            for loss_cfg_i in loss_cfg['types']:
                if loss_cfg_i['type'] == 'MixedLoss':
                    # A MixedLoss wraps several losses, e.g.
                    # [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}]
                    for loss_cfg_j in loss_cfg_i['losses']:
                        _check_helper(loss_cfg_j, ignore_index)
                else:
                    _check_helper(loss_cfg_i, ignore_index)

        self.show_msg(loss_name, loss_cfg)
        return {
            'coef': loss_cfg['coef'],
            "types": [self.build_component(item) for item in loss_cfg['types']],
        }

    @cached_property
    def train_dataset(self) -> paddle.io.Dataset:
        """Build the training dataset; it must not be empty."""
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, 'No train_dataset specified in the configuration file.'
        self.show_msg('train_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        assert len(dataset) != 0, \
            'The number of samples in train_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def val_dataset(self) -> paddle.io.Dataset:
        """Build the validation dataset; an empty one only triggers a warning."""
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        self.show_msg('val_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        if len(dataset) == 0:
            logger.warning('The number of samples in val_dataset is 0. Please ensure this is the desired behavior.')
        return dataset

    @cached_property
    def train_dataset_class(self) -> Any:
        """Return the class (not an instance) of the training dataset."""
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, 'No train_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_dataset_class(self) -> Any:
        """Return the class (not an instance) of the validation dataset."""
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_transforms(self) -> list:
        """Build the transforms listed in the validation dataset config."""
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, 'No val_dataset specified in the configuration file.'
        transforms = []
        for item in dataset_cfg.get('transforms', []):
            transforms.append(self.build_component(item))
        return transforms
# 特别注意:此处具体实现的类(如 class Cityscapes(Dataset) 等)称为组件;
# 组件管理器则是与之对应的模型(model)管理器、数据集(datasets)管理器等。
import inspect
from collections.abc import Sequence
import warnings
class ComponentManager:
    """
    Registry of components.

    A manager maps the name of a component to the component itself, so new
    components can be added and later looked up by name. A component is
    either a class or a function.

    Args:
        name (str, optional): The name of the manager. Default: None

    Examples 1:

        from paddleseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        class AlexNet: ...
        class ResNet: ...

        model_manager.add_component(AlexNet)
        model_manager.add_component(ResNet)

        # Or pass a sequence alternatively:
        model_manager.add_component([AlexNet, ResNet])
        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}

    Examples 2:

        # Or an easier way, using it as a Python decorator, while just add it above the class declaration.
        from paddleseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        @model_manager.add_component
        class AlexNet: ...

        @model_manager.add_component
        class ResNet: ...

        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
    """

    def __init__(self, name=None):
        self._components_dict = {}
        self._name = name

    def __len__(self):
        return len(self._components_dict)

    def __repr__(self):
        # Fall back to the class name when the manager has no name.
        name_str = self._name if self._name else self.__class__.__name__
        return "{}:{}".format(name_str, list(self._components_dict))

    def __getitem__(self, item):
        if item not in self._components_dict:
            raise KeyError("{} does not exist in available {}".format(item, self))
        return self._components_dict[item]

    @property
    def components_dict(self):
        """The dict that maps component names to components."""
        return self._components_dict

    @property
    def name(self):
        """The name given to the manager (may be None)."""
        return self._name

    def _add_single_component(self, component):
        """
        Add a single component to the manager.

        The component is stored under its ``__name__``. Adding a component
        whose name is already registered replaces the previous one and emits
        a warning.

        Args:
            component (function|class): A new component.

        Raises:
            TypeError: When `component` is neither class nor function.
        """
        # Only classes and functions are supported.
        if not (inspect.isclass(component) or inspect.isfunction(component)):
            raise TypeError("Expect class/function type, but received {}".format(type(component)))

        component_name = component.__name__

        # A duplicate name is not an error: warn, then overwrite.
        if component_name in self._components_dict:
            warnings.warn("{} exists already! It is now updated to {} !!!".format(component_name, component))
        self._components_dict[component_name] = component

    def add_component(self, components):
        """
        Add one component, or a sequence of components, to the manager.

        Returning the argument unchanged is what allows this method to be
        used as a decorator.

        Args:
            components (function|class|list|tuple): Support four types of components.

        Returns:
            components (function|class|list|tuple): Same with input components.
        """
        if isinstance(components, Sequence):
            for component in components:
                self._add_single_component(component)
        else:
            self._add_single_component(components)

        return components
# Manager of the models
MODELS = ComponentManager("models")
# Manager of the backbone networks
BACKBONES = ComponentManager("backbones")
# Manager of the datasets
DATASETS = ComponentManager("datasets")
# Manager of the data transforms (augmentations)
TRANSFORMS = ComponentManager("transforms")
# Manager of the loss functions
LOSSES = ComponentManager("losses")
# Manager of the optimizers
OPTIMIZERS = ComponentManager("optimizers")