提供了上层框架需要的 hook 机制以及可以直接使用的 runner
MMCV 提供了非常多的高性能 cuda 算子及其 python 接口
可参考https://zhuanlan.zhihu.com/p/336097883
fileio中的核心组件,设计文件读写。
mmcv提供了底层逻辑的读写handler,目前支持的有.json/.yaml/.yml/.pickle/.pkl文件
# 具体用法
import mmcv
# load data from a file
data = mmcv.load('test.json')
data = mmcv.load('test.yaml')
data = mmcv.load('test.pkl')
mmcv.dump(data, 'out.pkl')
mmcv支持自定义拓展的文件格式(即需要的文件格式不在上述列表),链接中给了.npy的例子。
其作用是对外提供统一的文件内容获取 API,主要用于训练过程中数据的后端读取,通过用户选择默认或者自定义不同的 FileClient 后端,可以轻松实现文件缓存、文件加速读取等等功能。
https://zhuanlan.zhihu.com/p/339190576
FileClinet用法示例,其实际调用在 mmseg/datasets/pipelines/loading.py/LoadImageFromFile
类中
class LoadImageFromFile(object): # 加载图片到内存中
"""Load an image from file.
Required keys are "img_prefix" and "img_info" (a dict that must contain the
key "filename"). Added or updated keys are "filename", "img", "img_shape",
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
Defaults to 'color'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
'cv2'
"""
def __init__(self,
to_float32=False,
color_type='color',
file_client_args=dict(backend='disk'),
imdecode_backend='cv2'):
self.to_float32 = to_float32
self.color_type = color_type
# 默认是disk后端
self.file_client_args = file_client_args.copy()
self.file_client = None
self.imdecode_backend = imdecode_backend
def __call__(self, results):
"""Call functions to load image and get image meta information.
Args:
results (dict): Result dict from :obj:`mmseg.CustomDataset`.
Returns:
dict: The dict contains loaded image and meta information.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
if results.get('img_prefix') is not None:
filename = osp.join(results['img_prefix'],
results['img_info']['filename'])
else:
filename = results['img_info']['filename']
# 读取图片字节内容
img_bytes = self.file_client.get(filename)
# 对字节内容进行解码
img = mmcv.imfrombytes(
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
results['ori_filename'] = results['img_info']['filename']
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
results['img_norm_cfg'] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(to_float32={self.to_float32},'
repr_str += f"color_type='{self.color_type}',"
repr_str += f"imdecode_backend='{self.imdecode_backend}')"
return repr_str
扩展开发示例提供了img文件和annotations 文件不在同一个地方的例子。
Config 主要是提供各种格式的配置文件解析功能,包括 py、json、ymal 和 yml,是一个非常基础常用类
https://zhuanlan.zhihu.com/p/346203167
mmseg/configs目录下的很多文件是用这个方法定义的
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
# 可以通过 .属性方式访问,比较方便
cfg.b.b1 # [0, 1]
该功能最为常用,配置文件可以是 py、yaml、yml 和 json 格式。
cfg = Config.fromfile('tests/data/config/a.py')
cfg.filename
cfg.item4 # 'test'
cfg # 打印 config path,和字典内容...
假设h.py文件里面存储的内容是:
cfg_dict = dict(
item1='{{fileBasename}}',
item2='{{fileDirname}}',
item3='abc_{{fileBasenameNoExtension }}')
则可以通过参数 use_predefined_variables
实现自动替换预定义变量功能
# cfg_file 文件名是 h.py
cfg = Config.fromfile(cfg_file, use_predefined_variables=True)
print(cfg.pretty_text)
# 输出
item1 = 'h.py'
item2 = 'config 文件路径'
item3 = 'abc_h'
该参数主要用途是自动替换 Config 类中已经预定义好的变量模板为真实值,在某些场合有用,目前支持 4 个变量:fileDirname、fileBasename、fileBasenameNoExtension 和 fileExtname
,预定义变量参考自 VS Code
如果 use_predefined_variables=False
( 默认为 True ),则不会进行任何替换。
Config.fromfile
函数除了有 filename
和 use_predefined_variables
参数外,还有 import_custom_modules
,默认是 True,即当 cfg中存在 custom_imports
键时候会对里面的内容进行自动导入,其输入格式要么是 str 要么是 list[str],表示待导入的模块路径,一个典型用法是:
在mmseg/datasets目录下新建greenscreen.py时,需要在__init__
里面加入
from .greenscreen import GreenScreenDataset
但是上述做法在某些场景下会比较麻烦。例如该模块处于非常深的层级,那么就需要逐层修改 __init__.py
,有了本参数,便可以采用如下做法:
# .py 文件里面存储如下内容
custom_imports = dict(
imports=['mmdet.models.backbones.mobilenet'],
allow_failed_imports=False)
# 自动导入 mmdet.models.backbones.mobilenet
Config.fromfile(cfg_file, import_custom_modules=True)
(1) 从 base 文件中合并 Config 支持基于单个 base 配置文件,然后合并其余配置,最终得到一个汇总配置,该功能在各大上层框架中使用非常频繁,可以极大的增加配置复用性。一个典型用法是:
# base.py 内容
item1 = [1, 2]
item2 = {'a': 0}
item3 = True
item4 = 'test'
# d.py 内容
_base_ = './base.py'
item1 = [2, 3]
item2 = {'a': 1}
item3 = False
item4 = 'test_base'
# 用法
cfg = Config.fromfile('d.py')
# 输出
item1 = [2, 3]
item2 = dict(a=1)
item3 = False
item4 = 'test_base'
(2) 从多个 base 文件中合并 Config 同时也支持多个 base 文件合并得到最终配置,用户只需要在非 base 配置文件中将类似 _base_ = './base.py'
改成 _base_ = ['./base.py',...]
即可。如配置configs/unet/deeplabv3_unet_s5_d16_256x256_40k_greenscreen.py时
_base_ = [
'../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/greenscreen.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
model = dict(test_cfg=dict(crop_size=(256, 256), stride=(170, 170)))
evaluation = dict(metric='mDice')
需要特别强调的是:
_base_
,否则程序不知道哪个字段才是 baseDuplicate Key Error
,因为此时不知道以哪个配置为主(3) 合并字典到配置 通过 cfg.merge_from_dict
函数接口可以实现对字典内容进行合并,典型用法如下:
cfg_file = osp.join(data_path, 'config/a.py')
cfg = Config.fromfile(cfg_file)
input_options = {'item2.a': 1, 'item2.b': 0.1, 'item3': False}
cfg.merge_from_dict(input_options)
# 原始 a.py 内容为:
item1 = [1, 2]
item2 = {'a': 0}
item3 = True
item4 = 'test'
# 进行合并后, cfg 内容
item1 = [1, 2]
item2 = dict(a=1, b=0.1)
item3 = False
item4 = 'test'
(4) allow_list_keys 模式合并 假设某个配置文件中内容为:
item = [dict(a=0), dict(b=0, c=0)]
可以通过如下方式修改 list 内容:
input_options = {'item.0.a': 1, 'item.1.b': 1}
cfg.merge_from_dict(input_options, allow_list_keys=True)
# 输出
item = [dict(a=1), dict(b=1, c=0)]
如果 input_options
内部索引越界或者 allow_list_keys=False
(默认是 True),则会报错。
(5) 允许删掉特定内容 该功能也比较常用,思考如下场景:在 RetinaNet 算法中,其采用的 bbox 回归 loss 配置如下:
loss_bbox=dict(type='L1Loss', loss_weight=1.0,其他参数)
上述配置是在 base 文件中,但是在 FASF 算法中采用的是 IOULoss,现在要做的事情是在 FASF 配置中自动覆盖掉 base 配置中的 L1Loss,可以采用如下做法:
loss_bbox=dict(
_delete_=True,
type='IoULoss',
eps=1e-6,
loss_weight=1.0,
reduction='none')
如果没有 _delete_=True
参数,则两个配置会自动合并,L1Loss
中的其他参数始终会保留,无法删除,这肯定是不正确的( IoULoss
中不需要 L1Loss
的初始化参数),现在通过引入 _delete_
保留字则可以实现忽略 base 相关配置,直接采用新配置文件字段功能。
pretty_text 函数可以将字典内容按照 PEP8 格式打印,输出结构清晰,非常好看,如下所示:
# 直接打印字典内容
print(cfg._cfg_dict)
# 输出
{'item1': [1, 2], 'item2': {'a': 1, 'b': 0.1}, 'item3': False, 'item4': 'test'}
# pretty 打印字典内容
print(cfg.pretty_text)
# 输出
item1 = [1, 2]
item2 = dict(a=1, b=0.1)
item3 = False
item4 = 'test'
上述功能是解决第三方库 yapf 实现。而 dump 功能就是将 cfg 内容保存,当想查看实验配置是否正确、查看实验记录以及复现以前实验结果时候非常有用。
见 https://zhuanlan.zhihu.com/p/346203167
Registry 用于提供全局类注册器功能
https://zhuanlan.zhihu.com/p/355271993
Registry 类可以提供一种完全相似的对外装饰函数来管理构建不同的组件,例如 backbones、head 和 necks 等等,Registry 类内部其实维护的是一个全局 key-value 对。通过 Registry 类,用户可以通过字符串方式实例化任何想要的模块。
例如在 Faster R-CNN 的 backbone 模块实例化时,可以采用如下配置:
backbone=dict(
type='ResNet', # 待实例化的类名
depth=50, # 后面的都是对于的类初始化参数
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
(1) 最简实现
# 方便起见,此处并未使用类方式构建,而是直接采用全局变量
_module_dict = dict()
# 定义装饰器函数
def register_module(name):
def _register(cls):
_module_dict[name] = cls
return cls
return _register
# 装饰器用法
@register_module('one_class')
class OneTest(object):
pass
@register_module('two_class')
class TwoTest(object):
pass
进行简单测试:
if __name__ == '__main__':
# 通过注册类名实现自动实例化功能
one_test = _module_dict['one_class']()
print(one_test)
# 输出
<__main__.OneTest object at 0x7f1d7c5acee0>
可以发现只要将所定义的简单装饰器函数作用到类名上,然后内部采用 _module_dict
保存信息即可
(2) 实现无需传入参数,自动根据类名初始化类
_module_dict = dict()
def register_module(module_name=None):
def _register(cls):
name = module_name
# 如果 module_name 没有给,则自动获取
if module_name is None:
name = cls.__name__
_module_dict[name] = cls
return cls
return _register
@register_module('one_class')
class OneTest(object):
pass
@register_module()
class TwoTest(object):
pass
进行简单测试:
if __name__ == '__main__':
one_test = _module_dict['one_class']
# 方便起见,此处仅仅打印了类对象,而没有实例化。如果要实例化,只需要 one_test() 即可
print(one_test)
two_test = _module_dict['TwoTest']
print(two_test)
# 输出
<class '__main__.OneTest '>
<class '__main__.TwoTest'>
(3) 实现重名注册强制报错功能
def register_module(module_name=None):
def _register(cls):
name = module_name
if module_name is None:
name = cls.__name__
# 如果重名注册,则强制报错
if name in _module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {name}')
_module_dict[name] = cls
return cls
return _register
新增一个 force 参数即可
def register_module(module_name=None,force=False):
def _register(cls):
name = module_name
if module_name is None:
name = cls.__name__
# 如果重名注册,则强制报错
if not force and name in _module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {name}')
_module_dict[name] = cls
return cls
return _register
测试:
@register_module('one_class')
class OneTest(object):
pass
@register_module('one_class',True)
class TwoTest(object):
pass
if __name__ == '__main__':
one_test = _module_dict['one_class']
print(one_test)
# 输出
<class '__main__.TwoTest'>
(5) 实现直接注册类功能
实现直接注册类的功能,只需要 _module_dict['name'] = module_class
即可。
class Registry:
def __init__(self, name):
# 可实现注册类细分功能
self._name = name
# 内部核心内容,维护所有的已经注册好的 class
self._module_dict = dict()
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
# 最核心代码
self._module_dict[module_name] = module_class
# 装饰器函数
def register_module(self, name=None, force=False, module=None):
if module is not None:
# 如果已经是 module,那就知道 增加到字典中即可
self._register_module(
module_class=module, module_name=name, force=force)
return module
# 最标准用法
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
在 MMCV 中所有的类实例化都是通过 build_from_cfg
函数实现,做的事情非常简单,就是给定 module_name
,然后从 self._module_dict
提取即可。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type') # 注册 str 类名
if is_str(obj_type):
# 相当于 self._module_dict[obj_type]
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
# 如果已经实例化了,那就直接返回
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
# 最终初始化对于类,并且返回,就完成了一个类的实例化过程
return obj_cls(**args)
一个完整的使用例子如下:
CONVERTERS = Registry('converter')
@CONVERTERS.register_module()
class Converter1(object):
def __init__(self, a, b):
self.a = a
self.b = b
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)
处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)
在我们熟知的 pytorch 中某个 tensor 或者 module 都有 register_hook(hook_fn)
函数,通过注册 hook,可以拦截和修改某些中间变量的值。
在 python 中要实现 hook 机制,非常简单,传入一个函数即可,如下是一个简单的 hook,该 hook 的功能是打印内部变量
def hook(d):
print(d)
def add(a,b,c,hook_fn=None)
sum1=a+b
if hook_fn is not None:
hook_fn(sum1)
return sum1+c
# 调用
add(1,2,3,hook)
在 PyTorch 中提供了非常方便的注册机制,用户可以随意插入任何函数来捕获中间过程,下面是一个简单的示例
import torch
from torch import nn
from mmcv.cnn import constant_init
# hook 函数,其三个参数不能修改(参数名随意),本质上是 PyTorch 内部回调函数
# module 本身对象
# input 该 module forward 前输入
# output 该 module forward 后输出
def forward_hook_fn(module, input, output):
print('weight', module.weight.data)
print('bias', module.bias.data)
print('input', input)
print('output', output)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc = nn.Linear(3, 1)
self.fc.register_forward_hook(forward_hook_fn)
constant_init(self.fc, 1)
def forward(self, x):
o = self.fc(x)
return o
运行输出:
if __name__ == '__main__':
model = Model()
x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)
# 输出
weight:tensor([[1., 1., 1.]])
bias: tensor([0.])
input: (tensor([[0., 1., 2.]]),)
output:tensor([[3.]], grad_fn=)