图像分割套件PaddleSeg全面解析(二)Config代码解读

上一篇介绍了图像分割PaddleSeg套件的整体情况,并介绍了训练的入口文件train.py。在train.py文件中会对配置文件进行解析,获得训练参数。这一篇主要介绍,如何通过Config类对配置文件进行解析。

Config类定义在paddleseg/cvlibs/config.py文件中。它保存了数据集配置、模型配置、主干网络的配置、损失函数配置等所有的超参数。

在PaddleSeg中,通过使用YAML文件的方式保存配置。该方法的好处是,只需要对YAML进行修改,或者创建新的YAML文件就可以新建一个训练任务。

YAML的语法比较简单,文件结构也很方便阅读,下面我们从图像分割最基础的FCN网络的配置文件开始了解一下如何从YAML文件生成Config对象。

举个例子,看一下dygraph/configs/fcn/fcn_hrnetw18_cityscapes_1024x512_80k.yml文件内容:

# _base_ 不是必须的,其作用更像基类。
# _base_指定的文件可以保存通用的配置,避免相同配置重复书写。若存在相同配置,会覆盖_base_指定yml文件的配置。
_base_: '../_base_/cityscapes.yml'

#模型信息
model:
  #模型的类型FCN
  type: FCN
  #使用的主干网络为HRNet 
  backbone:
    type: HRNet_W18
    #主干网络的预训练模型的下载地址。
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
  #模型分类数为19,可根据实际情况修改
  num_classes: 19
  #模型的预训练地址,这里为空
  pretrained: Null
  #这个是创建模型时需要传入的参数,该参数可以根据具体模型情况进行自定义设置,这个结合模型在具体讲解。
  backbone_indices: [-1]

#优化器设置,这里只设置了正则化的衰减系数,原因是因为在base里面已经设置了优化器的名称和学习率。
optimizer:
  weight_decay: 0.0005
#总迭代次数为80000次。
iters: 80000

下面在看一下cityscape.yml文件内容:

#如果fcn的配置文件,配置了相同内容会覆盖本配置内容。
batch_size: 4
#迭代次数
iters: 80000
#训练集配置
train_dataset:
  #类型为Cityscapes,这里的type对应的值会在Config类中实例化具体的对象,所以名字要跟类名一致。
  #Citycapes类保存在dygraph/paddleseg/datasets/cityscapes.py文件中
  type: Cityscapes
  #指定数据集的根目录,这里没有指定具体的文件List,是因为list是在Cityscape类中生成的。
  dataset_root: data/cityscapes
  #数据增强操作
  transforms:
  #每一个type 则代表了一个数据增强操作对应的类名。下面的值则为创建对象需要传递的参数。
    - type: ResizeStepScaling
      min_scale_factor: 0.5
      max_scale_factor: 2.0
      scale_step_size: 0.25
    - type: RandomPaddingCrop
      crop_size: [1024, 512]
    - type: RandomHorizontalFlip
    - type: Normalize
  #模式为训练模式
  mode: train
#验证集配置
val_dataset:
  type: Cityscapes
  dataset_root: data/cityscapes
  transforms:
    - type: Normalize
  #模式为验证集模式
  mode: val

#优化器设置。
optimizer:
  #优化器为SGG
  type: sgd
  #动量
  momentum: 0.9
  #正则化
  weight_decay: 4.0e-5

#学习率设置
learning_rate:
  #学习率
  value: 0.01
  #学习率衰减策略
  decay:
    type: poly
    power: 0.9
    end_lr: 0.0
#损失函数设置
loss:
  types:
    #支持多种损失函数
    - type: CrossEntropyLoss
  #损失权重,若包含多个损失函数,可以在此处设置权重,权重数量需要与损失函数数量一致。
  coef: [1]

上面介绍了yml配置文件的内容,下面解读Config类如何将yml文件转换为对象。Config代码比较长,下面截取重要的方法进行解读。

Config类的构造方法:

  def __init__(self,
               path: str,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
      #path为yml文件的路径,若果没有指定路径则抛出异常。
      if not path:
          raise ValueError('Please specify the configuration file path.')
      #还需要判断路径是否存在,如果不存在则抛出异常。
      if not os.path.exists(path):
          raise FileNotFoundError('File {} does not exist'.format(path))
      #初始化成员变量,模型对象和损失函数对象。
      self._model = None
      self._losses = None
      #判断配置文件类型是否为YAML。
      if path.endswith('yml') or path.endswith('yaml'):
          #如果文件类型正确,则通过_parse_from_yaml方法将文件内容保存到字典中。
          self.dic = self._parse_from_yaml(path)
      else:
          raise RuntimeError('Config file should in yaml format!')
      #更新配置中的learning_rate、batch_size和iters三个参数,这个三个参数是通过命令行传递过来的,
      #优先级高于yaml配置,会覆盖配置文件中的配置。
      self.update(
          learning_rate=learning_rate, batch_size=batch_size, iters=iters)

下面看一下在构造函数中遇到的_parse_from_yaml方法的源代码:

    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file and build config'''
        #首先打开配置文件,通过yaml库中的load方法转换为字典。yaml为第三方库,可以同pip安装。具体使用方法参考yaml相关文档。
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)
		#判断_base_是否在字典中,本次使用的FCN的配置文件是包含的也就是上面讲解的cityscape.yml文件。
        if '_base_' in dic:
            #同样获取cityscape.yml的路径然后通过本方法获取base配置的字典。
            cfg_dir = os.path.dirname(path)
            base_path = dic.pop('_base_')
            base_path = os.path.join(cfg_dir, base_path)
            #递归调用,因为cityscape.yml中并不包含_base_,所以下面的方法就不会执行到现在这部分代码。
            base_dic = self._parse_from_yaml(base_path)
            #更新dic字典中的内容。
            dic = self._update_dic(dic, base_dic)
        return dic

下面在讲解一下构造函数中的update方法,这个方法比较简单就是更新learning rate、batch size和iters。

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Update config'''
        #如果learning_rate存在,更新字典中的值。
        if learning_rate:
            self.dic['learning_rate']['value'] = learning_rate
        #更新batch_size
        if batch_size:
            self.dic['batch_size'] = batch_size
        #更新iters。
        if iters:
            self.dic['iters'] = iters

在_parse_from_yaml中调用_update_dic方法更新字典参数,我们看一下与上面update的区别。

    def _update_dic(self, dic, base_dic):
        """
        Update config from dic based base_dic
        """
        #首先复制一个basc_dic
        base_dic = base_dic.copy()
        #遍历dic中的键值对。
        for key, val in dic.items():
        	#如果dic中的值的类型为字典,同时这个键在base_dic中存在,则需要使用base_dic中值进行更新。
            #递归调用本方法进行更新,直到val类型是基本类型。
            if isinstance(val, dict) and key in base_dic:
                base_dic[key] = self._update_dic(val, base_dic[key])
            #如果是基本类型则直接更新,上面递归到此处会停止,在下面return处直接返回。
            else:
                base_dic[key] = val
        dic = base_dic
        return dic

Config类中还包含了很多以@property为注解的方法,对应了yaml配置文件中的train_dataset、val_dataset、model、loss等配置。前面提到过在这些配置中都会包含一个名字是type的键,它对应的值为类的名字。以property为注解的方法则会通过类的名字创建该对象,并将该对象返回给用户,此处使用的是懒加载的方式,只有当被调用的时候才会去创建。下面我们举例model属性来讲解一下,其他属性工作流程类似。

@property
  def model(self) -> paddle.nn.Layer:
      #从Config的配置字典中获取model的配置内容对应yaml文件中的部分如下:
      #model:
  	  #type: FCN
      #backbone:
      #		type: HRNet_W18
      #		pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
  	  #num_classes: 19
  	  #pretrained: Null
      #backbone_indices: [-1]
      
      model_cfg = self.dic.get('model').copy()
      #使用train_dataset配置中的类别数量覆盖model中的配置
      model_cfg['num_classes'] = self.train_dataset.num_classes
      #如果model_cfg 不存在则抛出异常
      if not model_cfg:
          raise RuntimeError('No model specified in the configuration file.')
      #在构造函数中_model配置为None,这里只创建一次模型对象。
      if not self._model:
          #创建模型对象。下面会继续解读_load_object方法。
          self._model = self._load_object(model_cfg)
      return self._model

_load_object方法解读:

 def _load_object(self, cfg: dict) -> Any:
 		#拷贝一份配置,因为需要通过type的值创建对象,所以如果cfg中不包含type键则会抛出异常。
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))
        #通过_load_component方法根据type的值获取类组件,这里的组件都是在定义各个类的时候通过
        #装饰器添加到manager维护的List中的,所以这里可以直接获取。至于如何加入list会在第3节接触到。
        component = self._load_component(cfg.pop('type'))
		#此处获取创建对象需要传递的参数,保存在params中。
        params = {}
        #遍历cfg中的键值对。
        for key, val in cfg.items():
            #这里使用_is_meta_type方法来判断val是字典同时也包含type值,如果包含的的话说明val对应的也是一个对象,
            #需要使用递归的方式获取到,直到参数类型为简单对象。
            if self._is_meta_type(val):
                params[key] = self._load_object(val)
            #如果参数是一个列表,则需要遍历列表中的内容,判断是否需要递归创建对象。
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            #遇到基本类型,保存参数。
            else:
                params[key] = val
		#遍历借宿创建对象。
        return component(**params)

至此Config类代码就解读完毕。

PaddleSeg仓库地址:https://github.com/PaddlePaddle/PaddleSeg

你可能感兴趣的:(深度学习,深度学习,神经网络)