上一篇介绍了图像分割PaddleSeg套件的整体情况,并介绍了训练的入口文件train.py。在train.py文件中会对配置文件进行解析,获得训练参数。这一篇主要介绍,如何通过Config类对配置文件进行解析。
Config类定义在paddleseg/cvlibs/config.py文件中。它保存了数据集配置、模型配置、主干网络的配置、损失函数配置等所有的超参数。
在PaddleSeg中,通过使用YAML文件的方式保存配置。该方法的好处是,只需要对YAML进行修改,或者创建新的YAML文件就可以新建一个训练任务。
YAML的语法比较简单,文件结构也很方便阅读,下面我们从图像分割最基础的FCN网络的配置文件开始了解一下如何从YAML文件生成Config对象。
举个例子,看一下dygraph/configs/fcn/fcn_hrnetw18_cityscapes_1024x512_80k.yml文件内容:
# _base_ 不是必须的,其作用更像基类。
# _base_指定的文件可以保存通用的配置,避免相同配置重复书写。若存在相同配置,会覆盖_base_指定yml文件的配置。
_base_: '../_base_/cityscapes.yml'
#模型信息
model:
#模型的类型FCN
type: FCN
#使用的主干网络为HRNet
backbone:
type: HRNet_W18
#主干网络的预训练模型的下载地址。
pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
#模型分类数为19,可根据实际情况修改
num_classes: 19
#模型的预训练地址,这里为空
pretrained: Null
#这个是创建模型时需要传入的参数,该参数可以根据具体模型情况进行自定义设置,这个结合模型在具体讲解。
backbone_indices: [-1]
#优化器设置,这里只设置了正则化的衰减系数,原因是因为在base里面已经设置了优化器的名称和学习率。
optimizer:
weight_decay: 0.0005
#总迭代次数为80000次。
iters: 80000
下面在看一下cityscape.yml文件内容:
#如果fcn的配置文件,配置了相同内容会覆盖本配置内容。
batch_size: 4
#迭代次数
iters: 80000
#训练集配置
train_dataset:
#类型为Cityscapes,这里的type对应的值会在Config类中实例化具体的对象,所以名字要跟类名一致。
#Citycapes类保存在dygraph/paddleseg/datasets/cityscapes.py文件中
type: Cityscapes
#指定数据集的根目录,这里没有指定具体的文件List,是因为list是在Cityscape类中生成的。
dataset_root: data/cityscapes
#数据增强操作
transforms:
#每一个type 则代表了一个数据增强操作对应的类名。下面的值则为创建对象需要传递的参数。
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: Normalize
#模式为训练模式
mode: train
#验证集配置
val_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
#模式为验证集模式
mode: val
#优化器设置。
optimizer:
#优化器为SGG
type: sgd
#动量
momentum: 0.9
#正则化
weight_decay: 4.0e-5
#学习率设置
learning_rate:
#学习率
value: 0.01
#学习率衰减策略
decay:
type: poly
power: 0.9
end_lr: 0.0
#损失函数设置
loss:
types:
#支持多种损失函数
- type: CrossEntropyLoss
#损失权重,若包含多个损失函数,可以在此处设置权重,权重数量需要与损失函数数量一致。
coef: [1]
上面介绍了yml配置文件的内容,下面解读Config类如何将yml文件转换为对象。Config代码比较长,下面截取重要的方法进行解读。
Config类的构造方法:
def __init__(self,
path: str,
learning_rate: float = None,
batch_size: int = None,
iters: int = None):
#path为yml文件的路径,若果没有指定路径则抛出异常。
if not path:
raise ValueError('Please specify the configuration file path.')
#还需要判断路径是否存在,如果不存在则抛出异常。
if not os.path.exists(path):
raise FileNotFoundError('File {} does not exist'.format(path))
#初始化成员变量,模型对象和损失函数对象。
self._model = None
self._losses = None
#判断配置文件类型是否为YAML。
if path.endswith('yml') or path.endswith('yaml'):
#如果文件类型正确,则通过_parse_from_yaml方法将文件内容保存到字典中。
self.dic = self._parse_from_yaml(path)
else:
raise RuntimeError('Config file should in yaml format!')
#更新配置中的learning_rate、batch_size和iters三个参数,这个三个参数是通过命令行传递过来的,
#优先级高于yaml配置,会覆盖配置文件中的配置。
self.update(
learning_rate=learning_rate, batch_size=batch_size, iters=iters)
下面看一下在构造函数中遇到的_parse_from_yaml方法的源代码:
def _parse_from_yaml(self, path: str):
'''Parse a yaml file and build config'''
#首先打开配置文件,通过yaml库中的load方法转换为字典。yaml为第三方库,可以同pip安装。具体使用方法参考yaml相关文档。
with codecs.open(path, 'r', 'utf-8') as file:
dic = yaml.load(file, Loader=yaml.FullLoader)
#判断_base_是否在字典中,本次使用的FCN的配置文件是包含的也就是上面讲解的cityscape.yml文件。
if '_base_' in dic:
#同样获取cityscape.yml的路径然后通过本方法获取base配置的字典。
cfg_dir = os.path.dirname(path)
base_path = dic.pop('_base_')
base_path = os.path.join(cfg_dir, base_path)
#递归调用,因为cityscape.yml中并不包含_base_,所以下面的方法就不会执行到现在这部分代码。
base_dic = self._parse_from_yaml(base_path)
#更新dic字典中的内容。
dic = self._update_dic(dic, base_dic)
return dic
下面在讲解一下构造函数中的update方法,这个方法比较简单就是更新learning rate、batch size和iters。
def update(self,
learning_rate: float = None,
batch_size: int = None,
iters: int = None):
'''Update config'''
#如果learning_rate存在,更新字典中的值。
if learning_rate:
self.dic['learning_rate']['value'] = learning_rate
#更新batch_size
if batch_size:
self.dic['batch_size'] = batch_size
#更新iters。
if iters:
self.dic['iters'] = iters
在_parse_from_yaml中调用_update_dic方法更新字典参数,我们看一下与上面update的区别。
def _update_dic(self, dic, base_dic):
"""
Update config from dic based base_dic
"""
#首先复制一个basc_dic
base_dic = base_dic.copy()
#遍历dic中的键值对。
for key, val in dic.items():
#如果dic中的值的类型为字典,同时这个键在base_dic中存在,则需要使用base_dic中值进行更新。
#递归调用本方法进行更新,直到val类型是基本类型。
if isinstance(val, dict) and key in base_dic:
base_dic[key] = self._update_dic(val, base_dic[key])
#如果是基本类型则直接更新,上面递归到此处会停止,在下面return处直接返回。
else:
base_dic[key] = val
dic = base_dic
return dic
Config类中还包含了很多以@property为注解的方法,对应了yaml配置文件中的train_dataset、val_dataset、model、loss等配置。前面提到过在这些配置中都会包含一个名字是type的键,它对应的值为类的名字。以property为注解的方法则会通过类的名字创建该对象,并将该对象返回给用户,此处使用的是懒加载的方式,只有当被调用的时候才会去创建。下面我们举例model属性来讲解一下,其他属性工作流程类似。
@property
def model(self) -> paddle.nn.Layer:
#从Config的配置字典中获取model的配置内容对应yaml文件中的部分如下:
#model:
#type: FCN
#backbone:
# type: HRNet_W18
# pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
#num_classes: 19
#pretrained: Null
#backbone_indices: [-1]
model_cfg = self.dic.get('model').copy()
#使用train_dataset配置中的类别数量覆盖model中的配置
model_cfg['num_classes'] = self.train_dataset.num_classes
#如果model_cfg 不存在则抛出异常
if not model_cfg:
raise RuntimeError('No model specified in the configuration file.')
#在构造函数中_model配置为None,这里只创建一次模型对象。
if not self._model:
#创建模型对象。下面会继续解读_load_object方法。
self._model = self._load_object(model_cfg)
return self._model
_load_object方法解读:
def _load_object(self, cfg: dict) -> Any:
#拷贝一份配置,因为需要通过type的值创建对象,所以如果cfg中不包含type键则会抛出异常。
cfg = cfg.copy()
if 'type' not in cfg:
raise RuntimeError('No object information in {}.'.format(cfg))
#通过_load_component方法根据type的值获取类组件,这里的组件都是在定义各个类的时候通过
#装饰器添加到manager维护的List中的,所以这里可以直接获取。至于如何加入list会在第3节接触到。
component = self._load_component(cfg.pop('type'))
#此处获取创建对象需要传递的参数,保存在params中。
params = {}
#遍历cfg中的键值对。
for key, val in cfg.items():
#这里使用_is_meta_type方法来判断val是字典同时也包含type值,如果包含的的话说明val对应的也是一个对象,
#需要使用递归的方式获取到,直到参数类型为简单对象。
if self._is_meta_type(val):
params[key] = self._load_object(val)
#如果参数是一个列表,则需要遍历列表中的内容,判断是否需要递归创建对象。
elif isinstance(val, list):
params[key] = [
self._load_object(item)
if self._is_meta_type(item) else item for item in val
]
#遇到基本类型,保存参数。
else:
params[key] = val
#遍历借宿创建对象。
return component(**params)
至此Config类代码就解读完毕。
PaddleSeg仓库地址:https://github.com/PaddlePaddle/PaddleSeg