detectron2利用别的数据集跑通训练的debug过程

在得到了coco格式的数据集之后,想把程序进行debug了解一下整体的运行过程和框架结构。

找到开始训练的文件,里面有配置文件的选择,模型的建立,数据的加载,优化器,损失函数等的设置。

这里我就有个问题了,官网里给的是tools文件夹下的运行train_net.py文件。然而我在一个知乎里面看到

  • 一般使用 tools/plain_train_net.py 来训练模型。
  • 最简单的训练结构是 SimpleTrainer().train()
  • 一般使用的类是 DefaultTrainer().train()

那这里就需要了解一下了,这三种有什么关系。

train_net.py文件的主函数里面调用了Trainer类。Train类继承于多个母类。继承关系和函数所在的文件如下图所示。总的来说就是Trainer继承于DefaultTrainer,DefaultTrainer继承于SimpleTrainer,SimpleTrainer继承于TrainerBase。

detectron2利用别的数据集跑通训练的debug过程_第1张图片

该图取自于 https://blog.csdn.net/maoymATshanghaitech/article/details/104837330

那么就从TrainerBase开始了解。一个这种训练类应该最少应该包含def train()和def run_step(),而run_step()会在train()方法中调用。run_step()是在每一次迭代中得到损失值,进行反向传播等操作。而在TrainerBase中,并没有对run_step()进行详细定义,并且在train()中进行了HOOK类的相关函数操作。所以TrainerBase类只是一个接口类。那么什么是HOOK类和HOOK类的相关操作呢。HOOK类的基函数是HOOKBASE类,该函数也在train_loop.py中。HOOKBASE也是一个接口类,里面定义了四种方法的接口。这四种方法分别在迭代开始前,迭代开始后,每一个step前和每一个step后,具体的hook类和对应的方法在hook.py中。这些方法应该都是一些统计训练时间这种辅助信息的方法。TrainerBase中还有一个注册hook的方法,这个注册就是给一个空list,然后把你要进行的hook类加入到list中,其实就是告诉代码你要进行哪些hook操作。注册方法中,有一行代码,h.trainer = weakref.proxy(self),h是一种hook类的实例。这个方法好像是将本trainer实例赋给h.trainer,使得hook中的方法能够获得该trainer训练过程中的结果。

我们提到TrainerBase中没有定义def run_step(),在SimpleTrainer中定义了def run_step()。SimpleTrainer已经是一个功能比较齐全的类了,在init()方法中,将model、data_loader、optimizer传入到该trainer类,然后在run_step()进行每一次训练的循环。然后使用_write_metrics()函数将每个step的训练结果存在storage中。

那么DefaultTrainer相对与SimpleTrainer又有哪些改进和不同呢?增加了创建model,optimizer,schedular,dataloader的操作,可以进行checkpoint的load,然后提供了一些比较公用的hook。DefaultTrainer().train(),它包含一个人们可能希望选择的更多标准默认行为。这也意味着它不太可能支持你在研究过程中可能想要的一些非标准行为。

那么plain_train_net.py又是什么呢?与train_net.py有什么区别呢?

plain_train_net.py与train_net.py类似,但是使用的是training loop而不是Trainer,适合自定义训练网络.

train_net.py一个用于训练detectron2内置模型的训练脚本示例。如果你想自定义某些训练逻辑,就可以按照train_net.py那样重写DefaultTrainer的某些方法,比如你希望使用自定义的mapper进行训练,那么只需要重写build_train_loader方法:

如果你觉得DefaultTrainer满足不了自己的训练需求,更进一步地你可以参考tools/plain_train_net.py实现自己的某些策略。

那么我们现在就使用train_net.py进行debug。因为coco数据集太大,我们只对两张voc格式的图像进行了转换,那么我们应该对这两张图像组成的数据集进行注册。虽然之前有个博客记录了关于注册的相关代码,但是现在看有点忘了,这里再重新梳理一下。

直接在在官方的tools/train_net.py上加上注册数据集部分。其中,DATASET_CATEGORIES,种类类别我的只有一种,就只写一个,但是要注意“id”要和json文件中的对应的类别的id一样。然后PREDEFINED_SPLITS_DATASET中要有训练集和验证集(测试集)。我一开始只有训练集,因为我只想看训练的过程,但是DefaultTrainer中的hook应该有关于验证集(测试集)的评估,所以需要有验证集(测试集),不然系统会出现bug。cfg.MODEL.ROI_HEADS.NUM_CLASSES这里指的是前景的类别树,修改为1。我用的是pycharm进行调试,想进行断点调试,为了保证程序能读取配置文件,需要在def setup(args)中表明需要加载的配置文件名。具体可参考

https://blog.csdn.net/weixin_39916966/article/details/103299051?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-4.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-4.nonecase

 

 

 

你可能感兴趣的:(detectron2利用别的数据集跑通训练的debug过程)