Caffe2 - (二十) Detectron 之 config.py 文件参数

Caffe2 - (二十) Detectron 之 config.py 文件参数

config.py 给出了 Detectron 的默认参数,其位于 lib/core/config.py. 类似于 Faster R-CNN 中对应的形式.

一般不直接更改该文件中的参数,而是通过编写 yaml 文件并利用 merge_cfg_from_file(yaml_file) 加载自定义参数,从而覆盖 config.py 内的默认参数.

设定 --cfg 参数,即可指定 yaml 参数文件,参数是以 (key, value) 对的形式.

Detectron 给出的 yaml 文件位于 configs/ 目录及其子目录中.

config.py 内参数理解:

# 工具包导入
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from ast import literal_eval
from past.builtins import basestring
from utils.collections import AttrDict
import copy
import logging
import numpy as np
import os
import os.path as osp
import yaml

from utils.io import cache_url 

logger = logging.getLogger(__name__)

__C = AttrDict()

# 导入 config 参数命令:
#   from core.config import cfg
cfg = __C

# 注:避免使用 '.ON' 作为参数配置中的 key. 因为 yaml 会将其转化为 True. 可使用 'ENABLED' 取代 '.ON'.

1. 训练参数 Training options

# ---------------------------------------------------------------------------- #
# 训练默认参数
# ---------------------------------------------------------------------------- #
__C.TRAIN = AttrDict()

# 从 .pkl 文件初始化网络权重参数;.pkl 可以是预训练模型.
__C.TRAIN.WEIGHTS = b''

# 训练数据集 Datasets
# 训练数据集列表设定:datasets.dataset_catalog.DATASETS.keys()
# 如果设定了多个 datasets,则会在其并集上进行模型训练.
__C.TRAIN.DATASETS = ()

# 训练所采用的缩放尺度 Scales
# 每一个 scale 是图像短边的像素值
# 如果给定了多个 scales 值,则对于每个训练图片随机选取一个 scale,如尺度抖动数据增强scale jitter data augmentation
__C.TRAIN.SCALES = (600, )

# 缩放后输入图片最长边的最大像素值
__C.TRAIN.MAX_SIZE = 1000

# 训练 mini-batch 每张 GPU 的图片数
# 每个 mini-batch 的总图片数 = 每张 GPU 的图片数 * GPUs 数
# Total images per minibatch = TRAIN.IMS_PER_BATCH * NUM_GPUS
__C.TRAIN.IMS_PER_BATCH = 2

# 每张图片的 RoI mini-batch,即每张图片的 RoIs 数.
# 每张训练 mini-batch 中总的 RoIs 数 = 每张图片 mini-batch RoIs 数 * mini-batch 图片数 * GPUs 数
# 典型配置: 512 * 2 * 8 = 8192
# RoI minibatch size *per image* (number of regions of interest [ROIs])
# Total number of RoIs per training minibatch =
#   TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS
__C.TRAIN.BATCH_SIZE_PER_IM = 64

# mini-batch 中被标记为 foreground RoIs(i.e. class > 0) 的目标比例 Target fraction
__C.TRAIN.FG_FRACTION = 0.25

# RoI 的重叠区域大于 FG_THRESH 则被标记为 foreground
__C.TRAIN.FG_THRESH = 0.5

# RoI 的重叠区域在 [LO, HI] 区间内则被标记为 background (i.e. class = 0)
__C.TRAIN.BG_THRESH_HI = 0.5
__C.TRAIN.BG_THRESH_LO = 0.0

# 训练是否水平翻转图片
__C.TRAIN.USE_FLIPPED = True

# 如果 RoI 和 groundtruth box 的重叠区域大于阈值BBOX_THRESH,则(RoI gt_box)对作为边界框 bounding-box 回归训练样本.
__C.TRAIN.BBOX_THRESH = 0.5

# 模型保存周期,即多少次迭代进行一次模型断点保存.
# 需要除以 GPUs 数 NUM_GPUS,e.g., 20000/8 => 2500 iters
__C.TRAIN.SNAPSHOT_ITERS = 20000

# 训练采用指定的 proposals 
# 训练过程中,所有的 proposals 是在 proposal 文件中指定的.
# proposals 文件与 TRAIN.DATASETS 数据集相对应.
__C.TRAIN.PROPOSAL_FILES = ()

# 确保图片 mini-batches 具有相同的长宽比,(i.e. both tall and thin or both short and wide)
# 对于节省内存很重要,可以稍微加快训练.
__C.TRAIN.ASPECT_GROUPING = True

2. RPN 训练参数 RPN training options

# ---------------------------------------------------------------------------- #
# RPN 训练默认参数
# ---------------------------------------------------------------------------- #

# 如果 anchor 和 groundtruth box 的最小重叠区域大于阈值 RPN_POSITIVE_OVERLAP, 
# 则 (anchor, gt_box) 对作为 positive 训练样本
# (IOU >= thresh ==> positive RPN example)
__C.TRAIN.RPN_POSITIVE_OVERLAP = 0.7

# 如果 anchor 和 groundtruth box 的最大重叠区域小于阈值 RPN_NEGATIVE_OVERLAP, 
# 则 (anchor, gt_box) 对作为 negative 训练样本
# (IOU < thresh ==> negative RPN example)
__C.TRAIN.RPN_NEGATIVE_OVERLAP = 0.3

# 每个 RPN mini-batch 中被标记为 foreground (positive) 样本的目标比例 Target fraction
__C.TRAIN.RPN_FG_FRACTION = 0.5

# 每张图片的 RPN 样本总数
__C.TRAIN.RPN_BATCH_SIZE_PER_IM = 256

# RPN proposals 所采用的 NMS 阈值 (end-to-end training with RPN 时使用)
__C.TRAIN.RPN_NMS_THRESH = 0.7

# NMS 处理前,top 分数的 RPN proposals 数
# When FPN is used, this is *per FPN level* (not total)
__C.TRAIN.RPN_PRE_NMS_TOP_N = 12000

# NMS 处理后,保留的 top 分数的 RPN proposals 数
# 所产生的 RPN proposals 总数(FPN 和 non-FPN 一样)
__C.TRAIN.RPN_POST_NMS_TOP_N = 2000

# 设定阈值像素值 RPN_STRADDLE_THRESH,丢弃超出图片边界的 RPN anchors
# 设定 RPN_STRADDLE_THRESH = -1 或 RPN_STRADDLE_THRESH = Large_Value(e.g. 100000),则不进行 anchors 裁剪.
__C.TRAIN.RPN_STRADDLE_THRESH = 0

# proposal 的 height 和 width 需要同时大于阈值RPN_MIN_SIZE
# (相对于原始图片尺度,不是训练或测试时的尺度)
__C.TRAIN.RPN_MIN_SIZE = 0

# 根据阈值CROWD_FILTER_THRESH 过滤在 crowd 区域的 proposals.
# "Inside" 的度量:proposal-with-crowd 交叉区域面积除以 proposal面积.
# "Inside" is measured as: proposal-with-crowd intersection area divided by proposal area.
__C.TRAIN.CROWD_FILTER_THRESH = 0.7

# 忽略面积小于阈值GT_MIN_AREA 的 groundtruth 物体
__C.TRAIN.GT_MIN_AREA = -1

# 如果FREEZE_CONV_BODY设定为 True,则冻结骨干backbone网络结构参数
__C.TRAIN.FREEZE_CONV_BODY = False

# 设定AUTO_RESUME=True 时,从输出路径中的最近模型断点snapshot 恢复训练
__C.TRAIN.AUTO_RESUME = True

3. 数据加载参数 Data loader options

# ---------------------------------------------------------------------------- #
# 数据加载默认参数
# ---------------------------------------------------------------------------- #
__C.DATA_LOADER = AttrDict()

# 数据加载所用的 Python 线程数threads
# (注:如果使用过多的线程,会出现 GIL 锁,导致训练变慢,实验发现 4 线程最佳.)
__C.DATA_LOADER.NUM_THREADS = 4

4. 推断/测试参数 Inference (‘test’) options

# ---------------------------------------------------------------------------- #
# 推断/测试默认参数
# ---------------------------------------------------------------------------- #
__C.TEST = AttrDict()

# 从 .pkl 文件初始化网络权重参数
__C.TEST.WEIGHTS = b''

# 测试数据集
# 类似于训练数据集,可用数据集列表设置: datasets.dataset_catalog.DATASETS.keys()
# 如果有多个数据集,则依次在每个数据集上进行测试.
__C.TEST.DATASETS = ()

# 测试时所采用的缩放尺度
# 每个缩放尺度是图片短边的像素值
# 如果给定多个缩放尺度,则所有的缩放尺度都在 multiscale inference 中使用.
__C.TEST.SCALES = (600, )

# 尺度缩放后输入图片的长边最大像素值
__C.TEST.MAX_SIZE = 1000

# NMS 的重叠阈值(suppress boxes with IoU >= this threshold)
__C.TEST.NMS = 0.3

# 如果 BBOX_REG 设定为 True,则采用类似于 Faster R-CNN bounding-box 回归的形式.
__C.TEST.BBOX_REG = True

# 利用给定 proposals 文件进行测试,(必须与TEST.DATASETS 相对应)
__C.TEST.PROPOSAL_FILES = ()

# 推断时每张图片的 proposals 数
__C.TEST.PROPOSAL_LIMIT = 2000

# RPN proposals 采用的 NMS 阈值.
__C.TEST.RPN_NMS_THRESH = 0.7

# NMS 处理前,top 分数的 RPN proposals 数
# When FPN is used, this is *per FPN level* (not total)
__C.TEST.RPN_PRE_NMS_TOP_N = 12000


# NMS 处理后,保留的 top 分数的 RPN proposals 数
# 所产生的 RPN proposals 总数(FPN 和 non-FPN 一样)
__C.TEST.RPN_POST_NMS_TOP_N = 2000

# proposal 的 height 和 width 需要同时大于阈值RPN_MIN_SIZE
# (相对于原始图片尺度,不是训练或测试时的尺度)
__C.TEST.RPN_MIN_SIZE = 0

# 每张图片返回的检测结果的最大数量(100 是根据 COCO 数据集设定的)
__C.TEST.DETECTIONS_PER_IM = 100

# 最小分数阈值,(假设分数已经归一化到 [0,1] 范围)
# 选定一个阈值,平衡 recall 和 precision,
# 阈值过低会得到高 recall 但低 precision 的检测结果,从而减慢推断的后处理过程(like NMS)
__C.TEST.SCORE_THRESH = 0.05

# 如果 COMPETITION_MODE 设为 True,则保存检测结果. 竞赛模式
# 如果 COMPETITION_MODE 设为 False,则运行后不会保存结果文件(可能是比较大的文件)
__C.TEST.COMPETITION_MODE = True

# 利用 COCO json 数据集的 eval code 来评估检测结果,即使它不是该数据集对应的评估 code.
# (e.g. 采用 COCO API 以 COCO 形式的 AP 来评估 PASCAL VOC 结果)
__C.TEST.FORCE_JSON_DATASET_EVAL = False

# [推断值,一般不在 config 文件直接设定]
# 如果 PRECOMPUTED_PROPOSALS = True,则表示使用预先计算的 proposals.
# 在 1-stage models and 2-stage models with RPN subnetwork 中设为 False.
__C.TEST.PRECOMPUTED_PROPOSALS = True

# [Inferred value; do not set directly in a config]
# Active dataset to test on
__C.TEST.DATASET = b''

# [推断值,一般不在 config 文件直接设定]
# 激活 proposal 文件,以便使用.
__C.TEST.PROPOSAL_FILE = b''

5. 边界框检测测试时间增强 Test-time augmentations for bounding box detection

# ---------------------------------------------------------------------------- #
# 边界框检测测试时间增强 Test-time augmentations for bounding box detection
# 例如 configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_2x.yaml for an example
# ---------------------------------------------------------------------------- #
__C.TEST.BBOX_AUG = AttrDict()

# True,开启边界框检测测试时间增强
__C.TEST.BBOX_AUG.ENABLED = False

# 用于组合预测box分数的启发式
#   可选参数: ('ID', 'AVG', 'UNION')
__C.TEST.BBOX_AUG.SCORE_HEUR = b'UNION'

# 用于组合预测box坐标的启发式
#   可选参数: ('ID', 'AVG', 'UNION')
__C.TEST.BBOX_AUG.COORD_HEUR = b'UNION'

# 以原缩放尺度scale 水平翻转(id transform)
__C.TEST.BBOX_AUG.H_FLIP = False

# 每个缩放尺度scale 是图片短边的像素尺寸
__C.TEST.BBOX_AUG.SCALES = ()

# 长边的最大像素尺寸
__C.TEST.BBOX_AUG.MAX_SIZE = 4000

# 在每个缩放尺度scale 水平翻转
__C.TEST.BBOX_AUG.SCALE_H_FLIP = False

# 基于物体object 尺寸应用缩放scaling
__C.TEST.BBOX_AUG.SCALE_SIZE_DEP = False
__C.TEST.BBOX_AUG.AREA_TH_LO = 50**2
__C.TEST.BBOX_AUG.AREA_TH_HI = 180**2

# 相对于图片宽度的每个长宽比aspect ratio
__C.TEST.BBOX_AUG.ASPECT_RATIOS = ()

# 在每个长宽比aspect ratio 水平翻转
__C.TEST.BBOX_AUG.ASPECT_RATIO_H_FLIP = False

6. Mask检测的测试时间增强Test-time augmentations for mask detection

# ---------------------------------------------------------------------------- #
# Mask检测的测试时间增强Test-time augmentations for mask detection
# 例如 configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_2x.yaml for an example
# ---------------------------------------------------------------------------- #
__C.TEST.MASK_AUG = AttrDict()

# True,开启实例分割的测试时间增强 instance mask detection
__C.TEST.MASK_AUG.ENABLED = False

# 用于组合 mask 检测的启发式
#   可选参数: ('SOFT_AVG', 'SOFT_MAX', 'LOGIT_AVG')
#   SOFT 前缀表示在 soft masks 进行计算
__C.TEST.MASK_AUG.HEUR = b'SOFT_AVG'

# 在原缩放尺度scale 水平翻转(id transform)
__C.TEST.MASK_AUG.H_FLIP = False

# 每个缩放尺度scale 是图片短边的像素尺寸
__C.TEST.MASK_AUG.SCALES = ()

# 长边的最大像素尺寸
__C.TEST.MASK_AUG.MAX_SIZE = 4000

# 在每个缩放尺度scale 水平翻转
__C.TEST.MASK_AUG.SCALE_H_FLIP = False

# 基于物体object 尺寸应用缩放尺度scale
__C.TEST.MASK_AUG.SCALE_SIZE_DEP = False
__C.TEST.MASK_AUG.AREA_TH = 180**2

# 每个长宽比相对于图片宽度
__C.TEST.MASK_AUG.ASPECT_RATIOS = ()

# 在每个长宽比aspect ratio 水平翻转
__C.TEST.MASK_AUG.ASPECT_RATIO_H_FLIP = False

7. 关键点检测测试增强Test-augmentations for keypoints detection

# ---------------------------------------------------------------------------- #
# 关键点检测测试增强Test-augmentations for keypoints detection
# 如 configs/test_time_aug/keypoint_rcnn_R-50-FPN_1x.yaml
# ---------------------------------------------------------------------------- #
__C.TEST.KPS_AUG = AttrDict()

# True,开启关键点检测测试时间增强
__C.TEST.KPS_AUG.ENABLED = False

# 用于组合关键点预测的启发式
#   可选参数: ('HM_AVG', 'HM_MAX')
__C.TEST.KPS_AUG.HEUR = b'HM_AVG'

# 在原始缩放尺度scale 水平翻转(id transform)
__C.TEST.KPS_AUG.H_FLIP = False

# 每个缩放尺度scale 是图片短边的像素尺寸
__C.TEST.KPS_AUG.SCALES = ()

# 长边的最大像素尺寸
__C.TEST.KPS_AUG.MAX_SIZE = 4000

# 在每个缩放尺度scale 水平翻转
__C.TEST.KPS_AUG.SCALE_H_FLIP = False

# 基于物体object 尺寸应用缩放尺度scaling
__C.TEST.KPS_AUG.SCALE_SIZE_DEP = False
__C.TEST.KPS_AUG.AREA_TH = 180**2

# 每个长宽比aspect ratio相对于图片宽度
__C.TEST.KPS_AUG.ASPECT_RATIOS = ()

# 在每个长宽比aspect ratio 水平翻转
__C.TEST.KPS_AUG.ASPECT_RATIO_H_FLIP = False

8. Soft NMS 参数

# ---------------------------------------------------------------------------- #
# Soft NMS
# ---------------------------------------------------------------------------- #
__C.TEST.SOFT_NMS = AttrDict()

# True,采用 soft NMS 代替 standard NMS
__C.TEST.SOFT_NMS.ENABLED = False
# See soft NMS paper for definition of these options
__C.TEST.SOFT_NMS.METHOD = b'linear'
__C.TEST.SOFT_NMS.SIGMA = 0.5
# For the soft NMS overlap threshold, we simply use TEST.NMS

9. 边界框投票Bounding box voting

# ---------------------------------------------------------------------------- #
# 边界框投票Bounding box voting (from the Multi-Region CNN paper)
# ---------------------------------------------------------------------------- #
__C.TEST.BBOX_VOTE = AttrDict()

# True,使用 box 投票
__C.TEST.BBOX_VOTE.ENABLED = False

# NMS步骤 使用的 TEST.NMS 阈值
#   对每个NMS得到的 box,VOTE_TH 重叠阈值用于选择投票 boxes (IoU >= VOTE_TH)
__C.TEST.BBOX_VOTE.VOTE_TH = 0.8

# 边界框投票时组合分数的方法
#  可选参数: ('ID', 'AVG', 'IOU_AVG', 'GENERALIZED_AVG', 'QUASI_SUM')
__C.TEST.BBOX_VOTE.SCORING_METHOD = b'ID'

# scoring方法的超参数(不同方法,意义不同)
__C.TEST.BBOX_VOTE.SCORING_METHOD_BETA = 1.0

10. 模型参数 Model options

# ---------------------------------------------------------------------------- #
# 模型参数
# ---------------------------------------------------------------------------- #
__C.MODEL = AttrDict()

# 使用模型的类型
# 字符串形式,与 modeling.model_builder 内的对应函数名一致,
# (e.g., 'generalized_rcnn', 'mask_rcnn', ...)
__C.MODEL.TYPE = b''

# 使用的骨干卷积网络 backbone conv body
# 字符串形式,与 modeling.model_builder 内的对应函数名一致,
# (e.g., 'FPN.add_fpn_ResNet101_conv5_body' 指定 ResNet-101-FPN 为骨干网络backbone)
__C.MODEL.CONV_BODY = b''

# 数据集内的类别classes 数;必须设定
# E.g., 81 for COCO (80 foreground + 1 background)
__C.MODEL.NUM_CLASSES = -1

# 采用类别未知边界框回归器(class agnostic bounding box regressor);
# 而不是默认逐类回归器(default per-class regressor)
__C.MODEL.CLS_AGNOSTIC_BBOX_REG = False

# (dx, dy, dw, dh)的默认权重,用于归一化 bbox 回归目标targets
# 经验值,逼近到单位方差目标(unit variance targets)
__C.MODEL.BBOX_REG_WEIGHTS = (10., 10., 5., 5.)

# FASTER_RCNN 的意思取决于其内容(training vs. inference):
# 1) 训练时,FASTER_RCNN = True 表示使用 end-to-end 训练方式联合训练 RPN 子网络和 Fast R-CNN 子网络
#    (Faster R-CNN = RPN + Fast R-CNN).
# 2) 推断时,FASTER_RCNN = True 表示使用模型的 RPN 子网络来生成 proposals,而不是预先计算的 proposals.
#    即使 Faster R-CNN 模型是逐阶段训练的(即,交替训练 RPN 和 Fast R-CNN 得到的),
#    也可以在推断时使用 FASTER_RCNN = True.
__C.MODEL.FASTER_RCNN = False

# True,表示模型输出实例分割预测结果 instance mask predictions (如 Mask R-CNN)
__C.MODEL.MASK_ON = False

# True,表示模型输出关键点预测结果 keypoint predictions (如 Mask R-CNN for keypoints)
__C.MODEL.KEYPOINTS_ON = False

# True,表示模型的计算终止于生成 RPN proposals
# (i.e., 只输出 proposals,不进行真正的目标检测)
__C.MODEL.RPN_ONLY = False

# Caffe2 网络net 执行类型
# 使用 'prof_dag' 进行分析统计
__C.MODEL.EXECUTION_TYPE = b'dag'

11. RetinaNet 参数

# ----------------------------------------------------------------------------#
# RetinaNet 参数
# ----------------------------------------------------------------------------#
__C.RETINANET = AttrDict()

# True,使用 RetinaNet (instead of Fast/er/Mask R-CNN/R-FCN/RPN)
__C.RETINANET.RETINANET_ON = False

# 所使用的 Anchor 长宽比aspect ratios
__C.RETINANET.ASPECT_RATIOS = (0.5, 1.0, 2.0)

# 每个 octave 的 Anchor 缩放尺度scale
__C.RETINANET.SCALES_PER_OCTAVE = 3

# 在每个 FPN 层次,基于其缩放尺度scale、长宽比aspect_ratio、层次步长stride of the level 生成 anchors,
# 并对得到的 anchor 乘以因子 ANCHOR_SCALE
__C.RETINANET.ANCHOR_SCALE = 4

# cls 和 bbox 使用的卷积数
# NOTE: this doesn't include the last conv for logits
__C.RETINANET.NUM_CONVS = 4

# bbox_regress loss 的权重
__C.RETINANET.BBOX_REG_WEIGHT = 1.0

# bbox regression 的 Smooth L1 loss beta
__C.RETINANET.BBOX_REG_BETA = 0.11

# 推断时,每个 FPN 层,在NMS处理前,基于 cls score 选择的 #locs
# During inference, #locs to select based on cls score before NMS is performed
# per FPN level
__C.RETINANET.PRE_NMS_TOP_N = 1000

# 标记 anchor 为 positive 的 IoU 重叠率
# Anchors with >= iou overlap 标记为 positive
__C.RETINANET.POSITIVE_OVERLAP = 0.5

# 标记 anchor 为 negative 的 IoU 重叠率
# Anchors with < iou overlap 标记为 negative
__C.RETINANET.NEGATIVE_OVERLAP = 0.4

# Focal loss 参数: alpha
__C.RETINANET.LOSS_ALPHA = 0.25

# Focal loss 参数: gamma
__C.RETINANET.LOSS_GAMMA = 2.0

# 训练开始时的 positives 的先验概率Prior prob
# 用于设置logits layer 的bias 初始化
__C.RETINANET.PRIOR_PROB = 0.01

# 是否共享 classification 和 bbox 分支
__C.RETINANET.SHARE_CLS_BBOX_TOWER = False

# 使用类别特定边界框回归(class specific bounding box regression),
# 而不是默认的类别无关回归(default class agnostic regression)
__C.RETINANET.CLASS_SPECIFIC_BBOX = False

# classification 分支训练是否使用 softmax
__C.RETINANET.SOFTMAX = False

# 推断时的 cls score 阈值
# anchors with score > INFERENCE_TH 用于推断
__C.RETINANET.INFERENCE_TH = 0.05

12. 求解器 Solver 参数

# ----------------------------------------------------------------------------#
# 求解器参数Solver
# 所有的 solver 参数被精确指定;意味着,如果训练从 1 GPU 切换到 N GPUs,必须调整对应的 solver 参数.
# 建议使用 gradual warmup 和 linear learning rate scaling rule,基于论文
# "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" Goyal et al.
# https://arxiv.org/abs/1706.02677
# ---------------------------------------------------------------------------- 
__C.SOLVER = AttrDict()

# 指定方案的基础学习率Base learning rate
__C.SOLVER.BASE_LR = 0.001

# 学习率策略(可选项见 utils.lr_policy 中的对应函数)
# E.g., 'step', 'steps_with_decay', ...
__C.SOLVER.LR_POLICY = b'step'

# 一些 LR 策略例示:
# 'step'
#   lr = SOLVER.BASE_LR * SOLVER.GAMMA ** (cur_iter // SOLVER.STEP_SIZE)
# 'steps_with_decay'
#   SOLVER.STEPS = [0, 60000, 80000]
#   SOLVER.GAMMA = 0.1
#   lr = SOLVER.BASE_LR * SOLVER.GAMMA ** current_step
#   iters [0, 59999] are in current_step = 0, iters [60000, 79999] are in
#   current_step = 1, and so on
# 'steps_with_lrs'
#   SOLVER.STEPS = [0, 60000, 80000]
#   SOLVER.LRS = [0.02, 0.002, 0.0002]
#   lr = LRS[current_step]

# 指定方案的超参数
# 对于 'step',在每一步,当前 LR 乘以因子 SOLVER.GAMMA
__C.SOLVER.GAMMA = 0.1

# 'step' 策略的均匀步长 Uniform step size
__C.SOLVER.STEP_SIZE = 30000

# 'steps_with_decay' 和 'steps_with_lrs'策略非均匀化步长 Non-uniform step iterations
__C.SOLVER.STEPS = []

# 采用 'steps_with_lrs' 策略的学习率
__C.SOLVER.LRS = []

# SGD 迭代的最多次数
__C.SOLVER.MAX_ITER = 40000

# SGD 动量 Momentum
__C.SOLVER.MOMENTUM = 0.9

# L2 正则化参数
__C.SOLVER.WEIGHT_DECAY = 0.0005

# SOLVER.BASE_LR 热身的 SGD 迭代次数
__C.SOLVER.WARM_UP_ITERS = 500

# 从 SOLVER.BASE_LR * SOLVER.WARM_UP_FACTOR 开始热身
__C.SOLVER.WARM_UP_FACTOR = 1.0 / 3.0

# WARM_UP_METHOD 方法可以是 'constant' 或 'linear' (i.e., gradual)
__C.SOLVER.WARM_UP_METHOD = 'linear'

# 当更新学习率时,采用 new_lr / old_lr 对动量momentum更新历史的缩放尺度scale:
# (this is correct given MomentumSGDUpdateOp)
__C.SOLVER.SCALE_MOMENTUM = True
# 仅当相对 LR 变化大于阈值时,才进行修正
# (避免小因子缩放动量momentum 时 linear warm up 所发生的变换;
#  如果 LR 变换较大时,momentum 缩放才比较重要)
# prevents ever change in linear warm up from scaling the momentum by a tiny
# amount; momentum scaling is only important if the LR change is large)
__C.SOLVER.SCALE_MOMENTUM_THRESHOLD = 1.1

# Suppress logging of changes to LR unless the relative change exceeds this
# threshold (prevents linear warm up from spamming the training log)
__C.SOLVER.LOG_LR_CHANGE_THRESHOLD = 1.1

13. Fast R-CNN 参数

# ---------------------------------------------------------------------------- 
# Fast R-CNN 参数
# ---------------------------------------------------------------------------- #
__C.FAST_RCNN = AttrDict()

# 用于边界框分类和回归的 RoI head 类型
# 字符串形式,必须与  modeling.model_builder 中对应的函数一致
# (e.g., 'head_builder.add_roi_2mlp_head' 指定了两个隐层的 MLP)
__C.FAST_RCNN.ROI_BOX_HEAD = b''

# 当使用 MLP 作为 RoI box head 时的隐层维度
__C.FAST_RCNN.MLP_HEAD_DIM = 1024

# RoI 变换函数(e.g., RoIPool or RoIAlign)
# (RoIPoolF 与 RoIPool 相同; 忽略尾部的 'F')
__C.FAST_RCNN.ROI_XFORM_METHOD = b'RoIPoolF'

# RoIAlign 中,网格采样(grid sampling points)点数(一般为 2)
# 只在 RoIAlign 中使用
__C.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO = 0

# RoI 变换的输出分辨率
# 注:某些模型可能对于其可以使用什么有约束;
# e.g. they use pretrained FC layers like in VGG16, and will ignore this option
__C.FAST_RCNN.ROI_XFORM_RESOLUTION = 14

14. RPN 参数

# ----------------------------------------------------------------------------#
# RPN 参数
# ----------------------------------------------------------------------------#
__C.RPN = AttrDict()

# [推测值;不直接在 config 中设定]
# True,表示模型包含 RPN 子网络
__C.RPN.RPN_ON = False

# 关于缩放网络输入的 RPN anchor 尺寸,以绝对像素值的形式
# RPN anchor sizes given in absolute pixels w.r.t. the scaled network input
# Note: these options are *not* used by FPN RPN; see FPN.RPN* options
__C.RPN.SIZES = (64, 128, 256, 512)

# RPN attached 的特征图步长
__C.RPN.STRIDE = 16

# RPN anchor 的长宽比aspect ratios
__C.RPN.ASPECT_RATIOS = (0.5, 1, 2)

15. FPN 参数

# --------------------------------------------------------------------------- #
# FPN 参数
# --------------------------------------------------------------------------- #
__C.FPN = AttrDict()

# True,开启 FPN
__C.FPN.FPN_ON = False

# FPN 特征层的通道维度Channel dimension
__C.FPN.DIM = 256

# True,初始化侧向连接lateral connections 输出 0
__C.FPN.ZERO_INIT_LATERAL = False

# 最粗糙coarsest FPN 层的步长
# 用于将输入正确地补零,是需要的
__C.FPN.COARSEST_STRIDE = 32

#
# FPN 可以只是 RPN、或只是目标检测,或两者都用.
#

# True, 采用 FPN 用于目标检测 RoI 变换
__C.FPN.MULTILEVEL_ROIS = False
# RoI-to-FPN 层的映射启发式 超参数
__C.FPN.ROI_CANONICAL_SCALE = 224  # s0
__C.FPN.ROI_CANONICAL_LEVEL = 4  # k0: where s0 maps to
# FPN 金字塔pyramid 的最粗糙层Coarsest level
__C.FPN.ROI_MAX_LEVEL = 5
# FPN 金字塔pyramid 的最精细层Finest level
__C.FPN.ROI_MIN_LEVEL = 2

# True,在 RPN 中使用 FPN
__C.FPN.MULTILEVEL_RPN = False
# FPN 金字塔pyramid 的最粗糙层Coarsest level
__C.FPN.RPN_MAX_LEVEL = 6
# FPN 金字塔pyramid 的最精细层Finest level
__C.FPN.RPN_MIN_LEVEL = 2
# FPN RPN anchor 长宽比aspect ratios
__C.FPN.RPN_ASPECT_RATIOS = (0.5, 1, 2)
# 在 RPN_MIN_LEVEL 上 RPN anchors 开始的尺寸
# RPN anchors start at this size on RPN_MIN_LEVEL
# The anchor size doubled each level after that
# With a default of 32 and levels 2 to 6, we get anchor sizes of 32 to 512
__C.FPN.RPN_ANCHOR_START_SIZE = 32
# 使用额外的 FPN 层levels, as done in the RetinaNet paper
__C.FPN.EXTRA_CONV_LEVELS = False

16. Mask R-CNN 参数

# --------------------------------------------------------------------------- #
# Mask R-CNN 参数 ("MRCNN" means Mask R-CNN)
# --------------------------------------------------------------------------- #
__C.MRCNN = AttrDict()

# 实例 mask 预测所用的 RoI head 类型
# 字符串形式,与 modeling.model_builder 对应的函数一致
# (e.g., 'mask_rcnn_heads.ResNet_mask_rcnn_fcn_head_v1up4convs')
__C.MRCNN.ROI_MASK_HEAD = b''

# 预测 mask 的分辨率
__C.MRCNN.RESOLUTION = 14

# RoI 变换函数和相关参数
__C.MRCNN.ROI_XFORM_METHOD = b'RoIAlign'

# RoI 变换的输出分辨率
__C.MRCNN.ROI_XFORM_RESOLUTION = 7

# RoIAlign 中网格采样点(grid sampling points)数(通常为2)
# 只用于 RoIAlign
__C.MRCNN.ROI_XFORM_SAMPLING_RATIO = 0

# mask head 中的通道channels 数
__C.MRCNN.DIM_REDUCED = 256

# mask head 中使用 dilated convolution
__C.MRCNN.DILATION = 2

# 预测 masks 的上采样因子
__C.MRCNN.UPSAMPLE_RATIO = 1

# True,采用全连接FC 层来预测最终的 masks
# False,采用卷积conv 层来预测最终的 masks
__C.MRCNN.USE_FC_OUTPUT = False

# mask head 和mask 输出层的权重初始化方法
__C.MRCNN.CONV_INIT = b'GaussianFill'

# True,使用类别特定 mask 预测(class specific mask predictions)
# False,使用类别未知 mask 预测(class agnostic mask predictions)
__C.MRCNN.CLS_SPECIFIC_MASK = True

# masks 的 multi-task loss 的权重
__C.MRCNN.WEIGHT_LOSS_MASK = 1.0

# soft masks 转换为 hard masks 的二值化阈值(Binarization threshold)
__C.MRCNN.THRESH_BINARIZE = 0.5

17. Keypoint Mask R-CNN 参数

# --------------------------------------------------------------------------- #
# Keypoint Mask R-CNN 参数 ("KRCNN" = Mask R-CNN with Keypoint support)
# --------------------------------------------------------------------------- #
__C.KRCNN = AttrDict()

# 实例关键点预测的 RoI head 的类型
# 字符串形式,与 modeling.model_builder 对应的函数一致
# (e.g., 'keypoint_rcnn_heads.add_roi_pose_head_v1convX')
__C.KRCNN.ROI_KEYPOINTS_HEAD = b''

# 输出 feature map 的尺寸(计算 loss 的尺寸) e.g., 56x56
__C.KRCNN.HEATMAP_SIZE = -1

# 上采样因子,使用双线性插值(bilinear interpolation)来上采样最终的 heatmap
__C.KRCNN.UP_SCALE = -1

# 对关键点 head 先验(prior) 计算的特征表示,采用 ConvTranspose 层处理,
# 以预测每个关键点的 heatmaps.
__C.KRCNN.USE_DECONV = False
# ConvTranspose 层产生的隐特征表示的通道维度(Channel dimension)
__C.KRCNN.DECONV_DIM = 256

# 采用 ConvTranspose 层来预测每个关键点的 heatmaps
__C.KRCNN.USE_DECONV_OUTPUT = False

# 在关键点 head 中使用 dilation
__C.KRCNN.DILATION = 1

# 所有的 ConvTranspose 操作的 kernels size
__C.KRCNN.DECONV_KERNEL = 4

# 数据集中关键点数(e.g., 17 for COCO)
__C.KRCNN.NUM_KEYPOINTS = -1

# 关键点 head 中的 stacked Conv layers 数
__C.KRCNN.NUM_STACKED_CONVS = 8

# 关键点 head 输出的特征表示的维度
__C.KRCNN.CONV_HEAD_DIM = 256

# 关键点 head 使用的 Conv kernel size
__C.KRCNN.CONV_HEAD_KERNEL = 3
# Conv 权重初始化
# Conv kernel weight filling function
__C.KRCNN.CONV_INIT = b'GaussianFill'

# True,基于 OKS 使用 NMS
__C.KRCNN.NMS_OKS = False

# Source of keypoint confidence
#   可选参数: ('bbox', 'logit', 'prob')
__C.KRCNN.KEYPOINT_CONFIDENCE = b'bbox'

# 标准的 RoI XFROM 参数 (如 FAST_RCNN 或 MRCNN 参数)
__C.KRCNN.ROI_XFORM_METHOD = b'RoIAlign'
__C.KRCNN.ROI_XFORM_RESOLUTION = 7
__C.KRCNN.ROI_XFORM_SAMPLING_RATIO = 0

# mini-batch 中必须存在的标记关键点的最小数
# 如果少于 MIN_KEYPOINT_COUNT_FOR_VALID_MINIBATCH,则丢弃该 mini-batch
__C.KRCNN.MIN_KEYPOINT_COUNT_FOR_VALID_MINIBATCH = 20

# 当从 heatmap 中推断关键点位置时,对低于最小尺寸INFERENCE_MIN_SIZE 的heatmap不进行缩放scale
__C.KRCNN.INFERENCE_MIN_SIZE = 0

# 用于关键点的 multi-task loss 权重
# 推荐值:
#   - use 1.0 if KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS is True
#   - use 4.0 if KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS is False
__C.KRCNN.LOSS_WEIGHT = 1.0

# True,根据 mini-batch 内可见(visible)关键点总数进行归一化
# False,根据 mini-batch 内存在的关键点总数进行归一化
__C.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS = True

18. R-FCN 参数

# --------------------------------------------------------------------------- #
# R-FCN 参数
# --------------------------------------------------------------------------- #
__C.RFCN = AttrDict()

# Position-sensitive RoI pooling output grid size (height and width)
__C.RFCN.PS_GRID_SIZE = 3

19. ResNets 参数

# --------------------------------------------------------------------------- #
# ResNets 参数 ("ResNets" = ResNet and ResNeXt)
# --------------------------------------------------------------------------- #
__C.RESNETS = AttrDict()

# groups 数; 1 ==> ResNet; > 1 ==> ResNeXt
__C.RESNETS.NUM_GROUPS = 1

# 每个 group 的 baseline width
__C.RESNETS.WIDTH_PER_GROUP = 64

# 在 1x1 filter 放置 stride 2 的 conv
# True, 只用于原始的 MSRA ResNet;
# False, 用于 C2 和 Torch 模型 
__C.RESNETS.STRIDE_1X1 = True

# 残差变换函数
# Residual transformation function
__C.RESNETS.TRANS_FUNC = b'bottleneck_transformation'

# 在 stage "res5" 采用 dilation
__C.RESNETS.RES5_DILATION = 1

20. Misc 参数

# --------------------------------------------------------------------------- #
# Misc 参数
# --------------------------------------------------------------------------- #

# GPUs 数 (同时用于 training 和 testing)
__C.NUM_GPUS = 1

# True, 使用 NCCL 加速;可能遇到 deadlocks 问题
# False,使用 muji
__C.USE_NCCL = False

# 图片坐标到特征图 feature map 坐标映射时,一些在图片空间不同坐标的 boxes 的 feature 坐标可能相同.
# 如果 DEDUP_BOXES > 0,DEDUP_BOXES 作为识别重复 boxes 的缩放因子scale factor.
# 1/16 is correct for {Alex,Caffe}Net, VGG_CNN_M_1024, and VGG16
__C.DEDUP_BOXES = 1 / 16.

# 修剪边界框变换预测结果(bounding box transformation predictions),以避免 np.exp 出现溢出(overflowing).
# 基于将 16 pixel anchor 缩放scale 到 1000 pixels 的启发式值
__C.BBOX_XFORM_CLIP = np.log(1000. / 16.)

# 像素均值(BGR 顺序) - (1, 1, 3) array数组形式
# 对于所有网络采用相同的像素均值,即使不是很精确
# "Fun" fact: the history of where these values comes from is lost
__C.PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])

# 用于重复实现结果
# For reproducibility...but not really because modern fast GPU libraries use
# non-deterministic op implementations
__C.RNG_SEED = 3

# A small number that's used many times
__C.EPS = 1e-14

# 项目根目录
__C.ROOT_DIR = os.getcwd()

# Output basedir
__C.OUTPUT_DIR = b'/tmp'

# matlab 
# Name (or path to) the matlab executable
__C.MATLAB = b'matlab'

# 基于 memonger gradient blob sharing 减少内存占用
# Reduce memory usage with memonger gradient blob sharing
__C.MEMONGER = True

# 如果 forward pass 激活可以共享,可以进一步优化内存
# Futher reduce memory by allowing forward pass activations to be shared when
# possible. Note that this will cause activation blob inspection (values,
# shapes, etc.) to be meaningless when activation blobs are reused.
__C.MEMONGER_SHARE_ACTIVATIONS = False

# 检测结果可视化 Dump detection visualizations
__C.VIS = False

# 可视化的score阈值
# Score threshold for visualization
__C.VIS_TH = 0.9

# Expected 检测,列表list 形式,每一元素有四个值:(dataset, task, metric, expected value)
# 如: [['coco_2014_minival', 'box_proposal', 'AR@1000', 0.387]]
__C.EXPECTED_RESULTS = []
# 与 EXPECTED_RESULTS 相比的绝对和相对偏差tolerance
__C.EXPECTED_RESULTS_RTOL = 0.1
__C.EXPECTED_RESULTS_ATOL = 0.005
# 如果 EXPECTED_RESULTS 失败,发送邮件
__C.EXPECTED_RESULTS_EMAIL = b''

# URL 下载的模型Models 和 proposals 的缓存目录
__C.DOWNLOAD_CACHE = b'/tmp/detectron-download-cache'

21. 集群Cluster 参数

# --------------------------------------------------------------------------- #
# 集群Cluster 参数
# --------------------------------------------------------------------------- #
__C.CLUSTER = AttrDict()

# True,表示 code 在集群环境中运行
# Flag to indicate if the code is running in a cluster environment
__C.CLUSTER.ON_CLUSTER = False

22. 其它参数

# --------------------------------------------------------------------------- #
# 丢弃参数
# 如果参数已经从代码里删除,但又不想修改已有的 yaml 参数配置,则可以将完整的 config key 作为字符串添加到下面的集.
# --------------------------------------------------------------------------- #
# Full (dotted) config keys that have been removed from the code. A key listed
# here is skipped with a warning instead of being rejected as unknown.
# NOTE: the identifier is misspelled ("DEPCRECATED") but _key_is_deprecated()
# looks it up under exactly this name, so the spelling is kept.
_DEPCRECATED_KEYS = {
    'FINAL_MSG',
    'MODEL.DILATION',
    'ROOT_GPU_ID',
    'RPN.ON',
    'TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED',
    'TRAIN.DROPOUT',
    'USE_GPU_NMS',
    'TEST.NUM_TEST_IMAGES',
}

# --------------------------------------------------------------------------- #
# 重命名参数
# 如果重命名了一个 config 里的参数,可以在以下 dict 里记录 old name 到 new name 的映射.
# 
# Optionally, if the type also changed, you can
# make the value a tuple that specifies first the renamed key and then
# instructions for how to edit the config file.
# --------------------------------------------------------------------------- #
# Maps an old full config key to its replacement. A value is either the new
# key, or a (new key, migration note) tuple when the value format changed too.
_RENAMED_KEYS = {
    'EXAMPLE.RENAMED.KEY': 'EXAMPLE.KEY',  # Dummy example to follow
    'MODEL.PS_GRID_SIZE': 'RFCN.PS_GRID_SIZE',
    'MODEL.ROI_HEAD': 'FAST_RCNN.ROI_BOX_HEAD',
    'MRCNN.MASK_HEAD_NAME': 'MRCNN.ROI_MASK_HEAD',
    'TRAIN.DATASET': (
        'TRAIN.DATASETS',
        "Also convert to a tuple, e.g., "
        "'coco_2014_train' -> ('coco_2014_train',) or "
        "'coco_2014_train:coco_2014_valminusminival' -> "
        "('coco_2014_train', 'coco_2014_valminusminival')",
    ),
    'TRAIN.PROPOSAL_FILE': (
        'TRAIN.PROPOSAL_FILES',
        "Also convert to a tuple, e.g., "
        "'path/to/file' -> ('path/to/file',) or "
        "'path/to/file1:path/to/file2' -> "
        "('path/to/file1', 'path/to/file2')",
    ),
}

23. 参数处理函数

# -------------------------------------------------------------------------- #
# 参数处理的一些定义函数
# -------------------------------------------------------------------------- #

def assert_and_infer_cfg(cache_urls=True):
    """Fill in the derived ("inferred") entries of the global config.

    - RPN.RPN_ON is switched on when MODEL.RPN_ONLY or MODEL.FASTER_RCNN is
      set.
    - TEST.PRECOMPUTED_PROPOSALS is switched off when RPN.RPN_ON or
      RETINANET.RETINANET_ON is set.

    Args:
        cache_urls: when truthy, cache_cfg_urls() is called afterwards.
    """
    needs_rpn = __C.MODEL.RPN_ONLY or __C.MODEL.FASTER_RCNN
    if needs_rpn:
        __C.RPN.RPN_ON = True

    # Evaluated after the assignment above so a freshly enabled RPN counts.
    generates_proposals = __C.RPN.RPN_ON or __C.RETINANET.RETINANET_ON
    if generates_proposals:
        __C.TEST.PRECOMPUTED_PROPOSALS = False

    if cache_urls:
        cache_cfg_urls()


def cache_cfg_urls():
    """Pass every weights / proposal-file entry of the config through
    cache_url() and store the returned value back into the config.

    Entries are processed in this order: TRAIN.WEIGHTS, TEST.WEIGHTS,
    TRAIN.PROPOSAL_FILES, TEST.PROPOSAL_FILES. Proposal files are stored back
    as tuples.
    """
    def _cached(location):
        # DOWNLOAD_CACHE is read on every call, as in a direct cache_url call.
        return cache_url(location, __C.DOWNLOAD_CACHE)

    __C.TRAIN.WEIGHTS = _cached(__C.TRAIN.WEIGHTS)
    __C.TEST.WEIGHTS = _cached(__C.TEST.WEIGHTS)
    __C.TRAIN.PROPOSAL_FILES = tuple(
        _cached(path) for path in __C.TRAIN.PROPOSAL_FILES
    )
    __C.TEST.PROPOSAL_FILES = tuple(
        _cached(path) for path in __C.TEST.PROPOSAL_FILES
    )


def get_output_dir(datasets, training=True):
    """Return the output directory for the current global config, creating it
    if it does not exist yet.

    Args:
        datasets: a dataset name (string) or a tuple/list of names; several
            names are joined with ':'.
        training: selects the 'train' (truthy) or 'test' (falsy) sub-directory.

    Returns:
        The directory path OUTPUT_DIR/{train|test}/dataset_name/MODEL.TYPE.
    """
    assert isinstance(datasets, (tuple, list, basestring)), \
        'datasets argument must be of type tuple, list or string'

    if isinstance(datasets, basestring):
        dataset_name = datasets
    else:
        dataset_name = ':'.join(datasets)

    if training:
        tag = 'train'
    else:
        tag = 'test'

    # Layout: OUTPUT_DIR/{train|test}/dataset_name/MODEL.TYPE
    outdir = osp.join(__C.OUTPUT_DIR, tag, dataset_name, __C.MODEL.TYPE)
    if not osp.exists(outdir):
        os.makedirs(outdir)
    return outdir


def merge_cfg_from_file(cfg_filename):
    """Load a yaml config file and merge it into the global config.

    Args:
        cfg_filename: path of the yaml file to read.

    Raises:
        KeyError: if the file contains a key that is unknown or was renamed
            (raised by _merge_a_into_b).
        ValueError: if a value's type cannot be reconciled with the default
            (raised by _check_and_coerce_cfg_value_type).
    """
    with open(cfg_filename, 'r') as f:
        # NOTE: deliberately safe_load instead of yaml.load(f). A bare
        # yaml.load() can construct arbitrary Python objects from the file and
        # is rejected by PyYAML >= 6, which requires an explicit Loader.
        # Config files only need plain scalars, lists and mappings.
        yaml_cfg = AttrDict(yaml.safe_load(f))
    _merge_a_into_b(yaml_cfg, __C)


def merge_cfg_from_cfg(cfg_other):
    """Merge `cfg_other` into the global config.

    Args:
        cfg_other: config to merge in; it is handed to _merge_a_into_b, which
            asserts that it is an AttrDict. Its values override the global
            ones.
    """
    _merge_a_into_b(cfg_other, __C)


def merge_cfg_from_list(cfg_list):
    """Apply a flat [key1, value1, key2, value2, ...] list (e.g., taken from
    the command line) to the global config.

    Example: `cfg_list = ['TEST.NMS', 0.5]`.

    Keys are full dotted paths. Deprecated keys are skipped, renamed keys
    raise a KeyError, and every other key must already exist in the config.
    Each value is decoded and type-checked against the current value before it
    is stored.
    """
    assert len(cfg_list) % 2 == 0
    pairs = zip(cfg_list[0::2], cfg_list[1::2])
    for full_key, raw_value in pairs:
        if _key_is_deprecated(full_key):
            continue
        if _key_is_renamed(full_key):
            _raise_key_rename_error(full_key)

        parts = full_key.split('.')
        parents, leaf = parts[:-1], parts[-1]

        # Walk down to the dict that holds the final key component.
        node = __C
        for name in parents:
            assert name in node, 'Non-existent key: {}'.format(full_key)
            node = node[name]
        assert leaf in node, 'Non-existent key: {}'.format(full_key)

        node[leaf] = _check_and_coerce_cfg_value_type(
            _decode_cfg_value(raw_value), node[leaf], leaf, full_key
        )


def _merge_a_into_b(a, b, stack=None):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.

    Args:
        a: AttrDict with the overriding values.
        b: AttrDict that is updated in place.
        stack: list of the parent keys leading to `a` (None at the top
            level); only used to build the full dotted key for messages and
            for the deprecated/renamed lookups.

    Raises:
        KeyError: if a key of `a` is missing from `b` and is either renamed
            or not deprecated.
        ValueError: if a value's type cannot be reconciled with the one in b.
    """
    assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'
    assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'

    for k, v_ in a.items():
        full_key = '.'.join(stack) + '.' + k if stack is not None else k
        # a must specify keys that are in b
        if k not in b:
            if _key_is_deprecated(full_key):
                continue
            elif _key_is_renamed(full_key):
                _raise_key_rename_error(full_key)
            else:
                raise KeyError('Non-existent config key: {}'.format(full_key))

        # Work on a copy so that `a` is never aliased into `b`.
        v = copy.deepcopy(v_)
        v = _decode_cfg_value(v)
        v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)

        # Recursively merge dicts. (The former `try/except BaseException:
        # raise` wrapper only re-raised, so it has been dropped.)
        if isinstance(v, AttrDict):
            stack_push = [k] if stack is None else stack + [k]
            _merge_a_into_b(v, b[k], stack=stack_push)
        else:
            b[k] = v


def _key_is_deprecated(full_key):
    """Return True (and log a warning) if `full_key` is a deprecated key.

    Args:
        full_key: full dotted config key, e.g. 'TRAIN.DROPOUT'.

    Returns:
        True if the key is listed in _DEPCRECATED_KEYS, otherwise False.
    """
    if full_key in _DEPCRECATED_KEYS:
        # Logger.warn is a deprecated alias; Logger.warning is the supported
        # spelling and emits the same record.
        logger.warning(
            'Deprecated config key (ignoring): {}'.format(full_key)
        )
        return True
    return False


def _key_is_renamed(full_key):
    """Return True if `full_key` (a full dotted config key) is listed in
    _RENAMED_KEYS, otherwise False.
    """
    return full_key in _RENAMED_KEYS


def _raise_key_rename_error(full_key):
    """Raise a KeyError telling the user that `full_key` has been renamed.

    The _RENAMED_KEYS entry is either the new key, or a (new key, note) tuple;
    in the tuple case the note is appended to the error message.

    Raises:
        KeyError: always.
    """
    replacement = _RENAMED_KEYS[full_key]
    note = ''
    if isinstance(replacement, tuple):
        replacement, hint = replacement[0], replacement[1]
        note = ' Note: ' + hint
    message = 'Key {} was renamed to {}; please update your config.{}'.format(
        full_key, replacement, note
    )
    raise KeyError(message)


def _decode_cfg_value(v):
    """Turn a raw config value (from a yaml file or a command-line argument)
    into a Python object.

    - A dict becomes an AttrDict.
    - A string is passed to literal_eval; if that fails with ValueError or
      SyntaxError the string is returned unchanged.
    - Anything else is returned unchanged.
    """
    # Raw yaml yields plain dicts; the config code expects AttrDict objects.
    if isinstance(v, dict):
        return AttrDict(v)

    if isinstance(v, basestring):
        # The string may *represent* a number, tuple, list, dict, boolean or
        # None, or it may simply be text. Plain text arrives without quotes
        # (foo rather than "foo"), so literal_eval rejects it: a bare word
        # raises ValueError and something like foo/bar raises SyntaxError.
        # In both cases the value really is a string and is kept as is.
        try:
            return literal_eval(v)
        except (ValueError, SyntaxError):
            return v

    # Non-string, non-dict values need no decoding.
    return v


def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
    """Validate the type of `value_a`, the intended replacement for `value_b`,
    and return it (possibly converted).

    The value is accepted unchanged when both types are identical. Otherwise
    it is converted in these cases, checked in this order:
      - `value_b` is a numpy array: `value_a` becomes an array of b's dtype;
      - `value_b` is a string: `value_a` is converted with str();
      - tuple replacing list, or list replacing tuple: converted to b's kind.

    Args:
        value_a: new value.
        value_b: current value, which defines the expected type.
        key: last component of the config key (not used in the body).
        full_key: full dotted key, used in the error message.

    Raises:
        ValueError: if the types differ and none of the conversions applies.
    """
    expected_type = type(value_b)
    given_type = type(value_a)
    if given_type is expected_type:
        return value_a

    if isinstance(value_b, np.ndarray):
        return np.array(value_a, dtype=value_b.dtype)
    if isinstance(value_b, basestring):
        return str(value_a)
    if isinstance(value_a, tuple) and isinstance(value_b, list):
        return list(value_a)
    if isinstance(value_a, list) and isinstance(value_b, tuple):
        return tuple(value_a)

    raise ValueError(
        'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
        'key: {}'.format(
            expected_type, given_type, value_b, value_a, full_key
        )
    )

你可能感兴趣的:(Caffe2,Caffe2,caffe2,detectron,config)