mmdetection检测训练和源码解读

源码注释已发布到github:https://github.com/ming71/mmdetection-annotated

目录

简单开始

检测demo

训练自己的数据集

mmdetection

mmcv库


简单开始

发布地址:https://github.com/open-mmlab/mmdetection

安装环境略,README文件有说,注意使用pytorch1.0就行,后面越来越多代码都会迁移到1.0.

该检测工具的特点是模块化封装,利用现有模块搭建自己的网络比较便利,而且提供给用户自己构建模块的通道,自由度比较高;贡献了很多新的算法,很值得学习一下;分布式训练,很多仓库都没有这个。由于其使用了商汤的mmcv库,在阅读代码的时候还有必要参考mmcv文档了解其实现,最好是直接把源码下下來看实现,至于自己写东西用不用就随意了。

检测demo

新建一个文件demo.py,键入官方给的代码就行,我自己的是以下内容,用钩子查看了下网络结构和部分层的输入输出,用这个也能跑出结果,不影响:

import mmcv
import torch
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
from mmdet.apis import inference_detector, show_result
import ipdb

def roialign_forward(module,input,output):
	print('\n\ninput:')
	print(input[0].shape,'\n',input[1].shape)

if __name__ == '__main__':
	params=[]
	def hook(module,input):
		# print('breakpoint')
		params.append(input)
		# print(input[0].shape)
		# data=input
	cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
	cfg.model.pretrained = None

	# ipdb.set_trace()

	# construct the model and load checkpoint
	model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
	print(model)
	handle=model.backbone.conv1.register_forward_pre_hook(hook)
	# model.bbox_roi_extractor.roi_layers[0].register_forward_hook(roialign_forward)
	
	_ = load_checkpoint(model, 'weights/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth')
	
	# test a single image
	img= mmcv.imread('/py/pic/2.jpg')
	result = inference_detector(model, img, cfg)
    
        print(params)	

	show_result(img, result)
	handle.remove()

需要改动的地方为:

  1. mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py'):设置为自己想要选择的模型配置文件,configs提供了很多
  2. load_checkpoint:路径改为下载好的权值文件存放的路径,比如我的:'weights/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'。权值文件要和config路径文件名匹配!
  3. 权值文件下载:https://github.com/open-mmlab/mmdetection/blob/master/MODEL_ZOO.md(注意:网速特别慢几Kb,解决方法是,用手机流量下,可达数百Kb
  4. img = mmcv.imread('test.jpg'):改为自己的测试图片地址

 运行demo.py,就可以查看效果了:

一点分析 

从运行的inference来看,只要更换config和weights文件就能用不同的网络检测。代码很容易发现,因为程序做了封装,build_detector函数搭建模型,inference_detector函数负责推理检测,查看函数内部不难发现,其将不同模块封装成backbone,neck,head等部分,在config中写入,通过读取配置,注册模块,进行封装,然后高级调用搭建网络。目前看来是这样,至于内部具体着怎么实现的后面还要仔细看代码才知道。

训练自己的数据集

使用coco格式进行训练比较方便,先可以使用labelimg标注自己的voc数据集得到xml转化成coco的json format,转换工具在我这个repo:https://github.com/ming71/voc_data_convert

然后创建文件夹,方式如下,其中后三个放图片,第一个存放json文件(命名尽量按这个 ,免得修改程序麻烦):

            mmdetection检测训练和源码解读_第1张图片                           mmdetection检测训练和源码解读_第2张图片

 然后配置configs文件,这里使用的是faster_rcnn_r50_fpn_1x.py,于是可以在其内修改训练测试的数据地址,训练方式存储路径等,也可以都不改,这里用coco数据集,直接跑下面命令就可以了。(运行的时候会自动找到model zoo网站下载resnet-50的backbone参数和模型)

python tools/train.py  configs/faster_rcnn_r50_fpn_1x.py

如果为了方便不想调这些东西,也可以直接用coco数据集,分别将jason和train的图片放进annotations和train2017,执行上述代码即可开始。 

mmdetection

详细注释发布到:https://github.com/ming71/mmdetection-annotated

mmcv库

官方文档:https://mmcv.readthedocs.io/en/latest/

很多地方写的比较草率,对接口描述不详细,参考源码一块读。

 

你可能感兴趣的:(运行记录,计算机视觉)