如果读不懂,是一些语法不太会
该代码基本上用的就是类的多级继承运行训练代码,不直观,
注册函数实现数据不同读取
然后就是hook 钩子函数,添加多个任务事件
该工程主要是 三个reid 最新论文的实现方法,添加了一些训练结构训练技巧,集成到一起,
小白reid不建议这个工程,可读性不是很好, 学习的话,从单个reid 论文方法先动手了解参数结构,然后再看
(菜鸟一枚,如有问题,欢迎批评)
代码结构构建参考
https://github.com/facebookresearch/detectron2
https://zhuanlan.zhihu.com/p/96931265
detectron2(目标检测框架)无死角玩转-06:源码详解(2)-Trainer继承关系,Hook
https://blog.csdn.net/weixin_43013761/article/details/104092658
1、super().init(model, data_loader, optimizer)
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
....
.....
##父类,SimpleTrainer,构造函数, def __init__(self, model, data_loader, optimizer):
##这里子类自己写了构造函数初始化,继承父类的构造函数需要 写 super
super().__init__(model, data_loader, optimizer)
类的继承, 继承后,子类写构造函数__init__ 需要用super 初始化父类构造函数,才可以继承父类,不继承父类构造函数
# default.py 类中, 有数据加载 data_loader = self.build_train_loader(cfg)
# 没有重新定义 __init__(self) 默认继承父类类所有构造函数
# 这里直接跳到 DefaultTrainer类中
class Trainer(DefaultTrainer): #类的继承,这里继承之后又添加了一个成员方法 build_evaluator
@classmethod
def build_evaluator(cls, cfg, num_query, output_folder=None):
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
return ReidEvaluator(cfg, num_query)
2、register 注册函数 读取数据,
代码中有四个装饰器,不改变代码功能的前提下增加函数功能
from …utils.registry import Registry
BACKBONE_REGISTRY = Registry(“BACKBONE”)
BACKBONE_REGISTRY.doc = “”"
通过上述三行注册装饰器函数
self._name
DATASET
====================================>
data root path /home/shiyy/nas/all_workspace/ReID/data
================================>
self._name
META_ARCH
================================>
self._name
BACKBONE
================================>
self._name
HEADS
不同数据的注册函数,得到数据,
自己写一个照着数据写法,注册一下
@ 装饰器函数的功能
模型结构装饰器
from ...utils.registry import Registry
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
def build_backbone(cfg):
"""
Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
Returns:
an instance of :class:`Backbone`
"""
backbone_name = cfg.MODEL.BACKBONE.NAME
#backbone_name = build_resnet_backbone
backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg) #打印出来,=build_resnet_backbone
return backbone
# backbone_name 参数的名字是 有 @ 的那一行,函数的名字,调用到resnet网络结构中的 函数
from .build import BACKBONE_REGISTRY #resnet.py 有这句导入,让两个文件连接在一起
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg):
_BASE_: "../Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 751
'''
再 market1501.py 数据处理函数中
@DATASET_REGISTRY.register() #下面是注册的数据类
class Market1501(ImageDataset):
通过注册函数,会调用到这句话
dataset = DATASET_REGISTRY.get(" Market1501")(root=_root, combineall=cfg.DATASETS.COMBINEALL)
开始到 market1501.py 的 Market1501类中,各种类的嵌套,读取数据
'''
DATASETS: #这里的名字需要是,数据处理的类名,被注册的 @DATASET_REGISTRY.register() ,否则找不到
NAMES: (" Market1501",) # from .market1501 import market1501
TESTS: (" Market1501",)
OUTPUT_DIR: "logs/market1501/bagtricks_R50"
register 函数做了什么
def register(self, obj: object = None) -> Optional[object]:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
函数功能返回 @类名字, 或者 函数的 import 导入路径
例如 数据返回 fastreid.data.datasets.merge_market1501.Mergedata_market1501
例如模型结构返回 build_resnet_backbone (resnet 中的def build_resnet_backbone)
print(func_or_class # #Market1501
#build_resnet_backbone
"""
if obj is None:
# used as a decorator
def deco(func_or_class: object) -> object:
name = func_or_class.__name__ # pyre-ignore
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__ # pyre-ignore
self._do_register(name, obj)
#下面返回的是 注册方式得到 函数方法,和类方法,impprt导入路径
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
777777777777777777777777777777777
3、训练数据,测试数据加载评估 hook 任务事件添加
主要代码在这里面,中间的各种任务是类的多级调用
D:\Projects\reid\fast-reid\fastreid\engine\defaults.py
'''
执行顺序
四个类,mro
class Trainer
(,
,
,
)
顺序执行super 之前的代码,原路返回再执行super 之后的代码
1、super 之前的代码
self.cfg = cfg
logger = logging.getLogger(__name__)
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for fastreid
setup_logger()
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
logger.info('Prepare training set')
data_loader = self.build_train_loader(cfg) #加载数据接口
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model)
if cfg.MODEL.BACKBONE.NORM == "syncBN":
# Monkey-patching with syncBN
patch_replication_callback(model)
model = model.cuda()
self._hooks = []
2 返回执行super 之后的代码
#
model.train() #model.eval()对应,训练模型,不是 (DefaultTrainer 中 def train (): super().train(self.start_iter, self.max_iter)
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# Assume no other objects need to be checkpointed.
# We can later make it checkpoint the stateful hooks
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.data_loader.dataset,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
)
self.start_iter = 0
if cfg.SOLVER.SWA.ENABLED:
self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
else:
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks()) #建立多个钩子函数任务列表,这里包含了数据测试(测试数据的加载和评估)
print("全部准备就绪准备训练")
3、 trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
#上述已经完成了数据的加载,模型的加载,多个函数功能的实现加载
print("这才是开始训练代码")
return trainer.train() #调用训练代码
'''