通过配置xml文件对MMDetection2.0部分参数进行快速设置并对其进行训练及测试的方法

通过配置xml文件快速设置mmd2.0并训练及测试的方法

  • 引言
  • 一、config.xml文件设置
  • 二、读取xml文件语句
  • 三、定义标签:edit_label.py
  • 四、修改MMDetection数据集配置文件
    • 修改class_names.py
    • 修改voc.py和coco.py
  • 五、数据集预处理:preprocess_data.py
  • 六、修改MMDetection模型配置文件
    • faster_rcnn_r50_fpn.py
    • coco_detection.py
    • schedule_1x.py
    • default_runtime.py
  • 七、训练模型:my_train.py
  • 八、解决程序报错
    • 报错1
    • 报错2
  • 九、测试程序:infer_test.py


引言

原先使用MMDetection1.0的时候,为了避免每次训练新模型时都要修改各种配置文件的麻烦,同事写了一个config.xml文件,并修改了MMDetection的各个文件,使之能读取该配置xml中的相应参数。这样在训练时,如果要修改标签、数据集大小、使用模型、训练次数等参数,可以通过修改xml文件中的相应参数来实现,而不需要进入到MMDetection的各个配置文件里面手动修改,增加了模型训练和测试的便捷性,节省了研发时间。

最近部门统一改用MMDetection2.11,发现与MMDetection1.0略微存在差别,深层次的差别我没有进行探究,最明显的是2.0版本对模型配置文件进行了拆分,但是每个拆分后的文件都还比较眼熟,算是熟面孔。这里参考同事写MMDetection1.0配置xml文件和修改MMDetection配置文件的思路,对MMDetection2.11的配置文件进行修改,使之能实现上述功能。


一、config.xml文件设置

首先,在config.xml文件里对所有在训练时会经常修改的参数进行预定义,具体实现方法如下:

将MMDetection2.11解压在mmdetection-master-v211文件夹中,并在此文件夹下进行编译,安装成功后,创建工程。文件夹名可以任意定义,但是在后面配置文件的编写中,该名称将作为锚点使用,用于不同文件之间互相查找并确定位置,因此一旦确定后,就不应当再进行修改。

在该文件夹下创建program文件夹,用于存放config.xml文件和运行程序。

config.xml文件内容如下:


	faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
	256,128
	suokou,guzhangac
	4
	12
	0.95
	Faster-RCNN
	/home/work2/023_X05_P01/400AF/fortest/

参数说明:
--module_type:定义训练和测试时使用的模型,示例中所示的路径在configs文件夹中。
--image_size:定义数据集中的图片在训练和测试时resize的大小(w, h),配置文件中不需要加括号。
--traindata_lable:标签名,在写配置xml时可以暂不定义,后面有程序可以读取数据集中的xml文件,确定数据集中所有的标签,并自动写入该xml中。如果对标签名的顺序有要求,则可以自定义。
--train_batch:定义训练时的batchsize。
--train_epochs:定义训练时的epoch。
--train_percent:定义划分数据集时数据集与验证集比例。
--save_path_name:定义训练结果的保存路径文件夹名,该文件夹将被创建在tools文件夹下。
--predict_imagePath:定义测试文件夹路径,如定义为0,则自动测试验证集。

注:这里只给出了我平时经常会修改的参数,如果还有其他参数修改频率较高,可在config.xml文件中继续添加参数,并在程序中进行相应修改。

二、读取xml文件语句

在配置好config.xml文件后,为了能够读取配置xml中的参数,在自定义程序和修改后的MMDetection配置文件中对相应参数进行设置,需要使用函数实现xml文件的读取和参数输出。该函数定义如下:

def config_parse(path, input_list):
	tree = ET.parse(path)
    root = tree.getroot()
    output = []
    for input_i in input_list:
    	output.append(root.find(input_i).text)
   	return output

函数功能:给定想要从config.xml文件中读取的参数标签名,输出查找结果。输入参数以列表形式给出。
如:[pathDirMmdet, save_path_name] = config_parse(path_config, ['pathDirMmdet','save_path_name'])

三、定义标签:edit_label.py

program文件夹下创建edit_label.py程序,由于与config.xml文件在同文件夹下,所以不需确定配置文件位置,可直接读取。
程序功能:自动读取数据集中的图片xml,确定所有标签,并写入config.xml中traindata_lable参数。
除去前面的config_parse()函数外,该文件还包括两个主要函数:get_labels()函数,用于读取图片xml并确定标签;和set_labels()函数,用于向config.xml文件写入标签名。
代码分别如下:

def get_labels(xml_path):
	xml_path_list = os.listdir(xml_path)
    labelList = []
    for xml_i in xml_path_list:
        if 'xml' in xml_i:
        file_path = os.path.join(xml_path, xml_i)
        tree = ET.parse(file_path)
        root = tree.getroot()
        for obj in root.findall('object'):
        	name = obj.find('name').text
            if name not in labelList:
            	labelList.append(name)
   	return labelList
def set_labels(path, labelList):
    tree = ET.parse(path)
    root = tree.getroot()
    root.find('traindata_lable').text = (',').join(labelList)
    tree.write(path)
    print('设置标签为:', (',').join(labelList))

由此,完成了训练和测试标签的自定义。但是此时MMDetection程序还不能读取自定义标签,仍在使用coco或voc数据集中的默认标签,因此还需要修改MMDetection中的一些数据集配置文件,使其标签与自定义标签相一致。

四、修改MMDetection数据集配置文件

要修改的文件主要包括:
~/mmdet/datasets/voc.py
~/mmdet/datasets/coco.py
~/mmdet/core/evaluation/class_names.py
其中,~表示主文件夹目录。

由于这几个文件与config.xml文件不在同级目录,所以读取文件内容时,要先确定config.xml文件位置,使用:
path_config = os.path.join(sys.argv[0].split('mmdetection-master-v211')[0], 'mmdetection-master-v211', 'programs', 'config.xml')
语句进行查找定位。

以列表形式读到标签内容:

[label_str] = config_parse(path_config, ['traindata_lable'])
label_list = label_str.split(',')

修改class_names.py

对于class_names.py文件,修改后的voc_classes()如下:

def voc_classes():
	print('\n【class_names.py -> voc_classes()】', '\npathPyClassVoc:', pathPyClassVoc)
    [label_str] = config_parse(path_config, ['traindata_lable'])
    label_list = label_str.split(',')
    print('voc list_class:', label_list)
    print('len(list_class):', len(label_list))
    return label_list

下面的coco_classes()同理。

class_names.py中,label_list本身的类别就是列表,因此无需进行进一步的修改,而在voc.pycoco.py中,还需进一步修改。

修改voc.py和coco.py

由于MMDetection中标签名是以元组形式存储的,所以要使用语句CLASSES = tuple(label_list)将标签列表转化为元组,并且注释掉原本对CLASSES的定义。

特别注意的是:之前我参考的教程里说,如果只有一个标签,比如数据集中的标签只有“suokou”,则在写CLASSES的时候,要写为CLASSES=('suokou',),即在唯一的标签名后加一个逗号,不然会报错。我尝试了一下,发现确实如此,刚开始不知道原因,后来无意中看到在Python中,如果定义一个元组,二元组中只有一个元素时,元组tuple会被转化为字符串str,而如果在这唯一的元素后增加一个逗号,则可以避免这一情况。具体如下:

>>>a = ('x')
>>> type(a)
<class 'str'>

>>> b = ('x',)
>>> type(b)
<class 'tuple'>

而在MMDetection中,CLASSES的类型必须是元组,不然则会报错。因此前面所说,单一标签时额外添加的逗号是这个原因。而使用CLASSES = tuple(label_list)转化后的标签,及时只有单一元素,转化后的也是元组,不会变为字符串,不需要手动在末尾添加逗号。

修改后的class VOCDataset()如下:

class VOCDataset(XMLDataset):
	path_config = os.path.join(sys.argv[0].split('mmdetection-master-v211')[0]
	, 'mmdetection-master-v211','programs', 'config.xml')
	[label_str] = config_parse(path_config, ['traindata_lable'])
    label_list = label_str.split(',')
    CLASSES = tuple(label_list)

coco.py中的CocoDataset()同理。

五、数据集预处理:preprocess_data.py

program文件夹下创建preprocess_data.py程序。
程序功能:按config.xml文件中train_percent参数设定的比例划分训练集与验证集;将voc数据集转为coco数据集,将instances_train2017.jsoninstances_val2017.jsoninstances_test2017.json和训练集、验证集图片保存在config.xml文件中save_path_name参数设定的保存路径下相应的coco文件夹中。

这部分代码网上参考很多,可以直接拿来用,不做赘述。

六、修改MMDetection模型配置文件

faster_rcnn_r50_fpn_1x_coco.py为例,打开该文件,发现比起MMDetection1.0, 2.0版本将该模型拆分成四个部分:

_base_ = ['../_base_/models/faster_rcnn_r50_fpn.py',
		  '../_base_/datasets/coco_detection.py',
		  '../_base_/schedules/schedule_1x.py',
		  '../_base_/default_runtime.py'
          ]

对这四个文件以此进行修改。

faster_rcnn_r50_fpn.py

该文件中,对网络结构进行了定义,主要需要修改的变量为标签类别数numClass

获得numClass 的方法:读取config.xml文件的标签,求标签列表的长度。值得注意的是,MMDetection2.0中,背景不再单独作为一类,所以numClass 不需在标签数的基础上+1,有几个标签,numClass 就写多少。

其他的一些参数有时也会修改,比如backbonetyperpn_headanchor_generatorscalesratios等,但我的修改频率不高,所以没有单拎出来作为可以在config.xml中直接修改的变量,如果遇到需要修改这些参数的情况,需要进到faster_rcnn_r50_fpn.py文件中进行手动修改。

coco_detection.py

该文件主要定义训练和测试集路径、图像缩放尺寸、训练批大小等。

数据存放的主路径为config.xml文件中save_path_name参数定义的路径,数据均按COCO格式存放,不需额外修改。

图像缩放尺寸img_scale参数、训练批大小batchSize均由config.xml定义。

其中,MMDetection要求img_scale为元组格式,因此读入config.xml里自定义的trainimage_size参数后,需使用scaleImgResize = tuple([int(i) for i in imgSize.split(',')])语句对其进行转化。

batchSize的定义很简单,只需读入config.xmltrain_batch参数,再使用int()将其转化为整数即可。

schedule_1x.py

该文件主要定义学习策略,修改初始学习率、训练次数、优化器类型等。

其中,训练次数通过config.xml文件进行设置,其他参数暂时不动,使用默认。

default_runtime.py

该文件主要定义预训练权重,是否从断点继续训练模型等,不需要额外修改。

顺带一提,该文件中resume_fromload_from的区别在于:
resume_from不仅要从checkpoint文件中读取权重也需要得到特定的优化器状态和epoch数目, 用于程序运行过程中中断后继续训练;
load_from仅用于加载模型并微调。

七、训练模型:my_train.py

program文件夹下创建my_train.py程序。

程序的主要功能是调用系统自带的train.py训练程序,调用时输入自定义训练参数,包括使用的模型和权重保存路径等,上述参数均由config.xml文件获得。

运行程序,发现程序报错。由于报错涉及到MMDetection2.0的代码,与my_train.py程序本身无关,所以放在下一部分进行解决。

八、解决程序报错

报错1

在进行上述修改之后,直接运行my_train.py程序,会发现报如下错误:

Traceback (most recent call last):
  File "/home/work1/mmdetection-master-v211/tools/train.py", line 187, in <module>
    main()
  File "/home/work1/mmdetection-master-v211/tools/train.py", line 89, in main
    cfg = Config.fromfile(args.config)
  File "/home/work1/mmdetection-master-v211/mmcv/utils/config.py", line 257, in fromfile
    use_predefined_variables)
  File "/home/work1/mmdetection-master-v211/mmcv/utils/config.py", line 183, in _file2dict
    raise KeyError('Duplicate key is not allowed among bases')
KeyError: 'Duplicate key is not allowed among bases'

这时就需要通过debug定位报错位置,在train.pycfg = Config.fromfile(args.config)这行添加断点,进入到~/mmcv/utils/config.py程序中,再进入cfg_dict, cfg_text = Config._file2dict(filename,use_predefined_variables)函数,发现报错原因是函数的这部分语句:

for c in cfg_dict_list:
	if len(base_cfg_dict.keys() & c.keys()) > 0:
		raise KeyError('Duplicate key is not allowed among bases')
	base_cfg_dict.update(c)

检查cfg_dict_list的内容,发现该变量为长度为4的列表,再检查列表中的内容,发现列表中每一元素均为长度不一的字典,字典的key比较眼熟,包括'ET''os''sys'等为了读取config.xml而自己导入的包;'imgSize''save_path_name''train_epoch'等在读取config.xml时定义的中间变量;和'numClass''batchSize''maxEpoch'等各个模型配置文件中官方定义的模型参数。

由此可以确定,列表中的四个元素分别对应faster_rcnn_r50_fpn_1x_coco.py文件下_base_中的四个模型配置文件,即以下四个文件。

_base_ = ['../_base_/models/faster_rcnn_r50_fpn.py',
		  '../_base_/datasets/coco_detection.py',
		  '../_base_/schedules/schedule_1x.py',
		  '../_base_/default_runtime.py'
          ]

而通过单步运行程序可以发现,程序在通过for c in cfg_dict_list这行语句遍历列表中的每个元素,即遍历上述四个配置文件中引入的参数时,有部分参数同时出现在了多个配置文件中,导致列表中存储的四个字典中存在重复的key,触发了报警程序。而正常情况下,由于这四个配置文件设置了不同的参数,彼此之间没有重复,所以不会出现参数重复的情况,因此也就不会报警。

重复的参数包括:

  • import的包:'ET''os''sys',在四个配置文件中均import了这三个包
  • 读取config.xml时的中间变量:path_config

这时我有了一个大胆的想法:反正多出点变量影响又不大,又没碍着谁,我把这行报警注释掉试试呢?

(先改了再说,只要程序能跑起来,哪管身后洪水滔天(不是))

于是我注释掉了这两行程序:

if len(base_cfg_dict.keys() & c.keys()) > 0:
	raise KeyError('Duplicate key is not allowed among bases')

很好,确实不报这个错了呢!(要是还报就怪了)

然而这时报了新的错误——

报错2

……
  File "/home/work1/mmdetection-master-v211/mmcv/utils/config.py", line 418, in pretty_text
    text, _ = FormatCode(text, style_config=yapf_style, verify=True)
……
 ParseCodeToTree
    ast.parse(code)
  File "/home/ty/anaconda3/envs/mmdet2/lib/python3.7/ast.py", line 35, in parse
    return compile(source, filename, mode, PyCF_ONLY_AST)
  File "", line 1
    ET=<module 'xml.etree.ElementTree' from '/home/ty/anaconda3/envs/mmdet2/lib/python3.7/xml/etree/ElementTree.py'>
       ^
SyntaxError: invalid syntax

看起来还是自己导入的包惹的祸。

但是这次的报错就乱七八糟地报了一堆,进入到各种乱七八糟的程序里,乱七八糟的也看不明白。

考虑到报警里涉及到我自己导入的包,报警原因肯定还是我对模型配置文件的修改,所以我分析问题还是出现在程序在运行时,我导入的包和自定义的参数对后面程序的某些运行过程造成了一些影响。所以还是同样在debug中单步执行config.py程序,又进入到_file2dict函数中,程序顺利运行过之后,检查base_cfg_dict这个变量,通过与修改前的变量进行比对,这样就发现了问题:

我的base_cfg_dict比程序未修改时的base_cfg_dict多了key,和前面重复的变量一样,多的内容主要就是我import的包:'ET''os''sys',和在四个配置文件中读取config.xml时的各种中间变量。

这个也很好理解,因为我只是单纯注释掉了重复变量,而这些我自定义的内容,全都通过base_cfg_dict.update(c)这行语句加入到base_cfg_dict字典中,很显然这个字典中的内容在后面的程序中被用到了,程序读取到了我自定义的变量,又没有相应的处理手段,因此报错。

所以我又有了一个大胆的想法:我把这些我自定义的变量都从这个字典里简单直接粗暴地删掉,让字典恢复如初呢?

于是我仔细比对了修改前后的base_cfg_dict内容,对自己新加入的键值十分暴力地进行删除:

del base_cfg_dict['ET'], base_cfg_dict['os'], base_cfg_dict['sys']\
    , base_cfg_dict['config_parse'], base_cfg_dict['path_config']\
    , base_cfg_dict['label_str'], base_cfg_dict['label_list'], base_cfg_dict['save_path_name']\
    , base_cfg_dict['img_size'], base_cfg_dict['train_batch'], base_cfg_dict['train_epochs']

“啪”的一下,很快啊!问题解决了!

程序跑起来了!

简直喜闻乐见!大快人心!普天同庆!奔走相告!

这一刻,我的脑海中不禁出现了这张图:

通过配置xml文件对MMDetection2.0部分参数进行快速设置并对其进行训练及测试的方法_第1张图片
于是我火速备份了工程,写下了这篇文章,并发誓不再动它。

然鹅不行,我还得搞个程序来实现测试。

九、测试程序:infer_test.py

program文件夹下创建infer_test.py程序。

实现功能:设置测试时的batchsize,根据config.xml中设置的测试路径,按批测试图片;并将测试结果(包括图片和标签、置信度)保存在~/result文件夹下。

按批测试功能在TensorFlow相关文章中有很多代码可以参考,毕竟Google API很久之前就支持按批测试了,这里采取的思路是:

  • 计算图片总数除以batchsize的余数,即为需要填充的图片数,生成相应的纯黑图片,保存在文件夹中。

    list_name_test_img = os.listdir(pathDirTestImg)  # pathDirTestImg为测试路径,默认路径下全为图片
    num_add = batchSize - len(list_name_test_img) % batchSize
    for i in range(num_add):
    	img_ = np.zeros(img_size, dtype=np.uint8)
    	path_img_ = os.path.join(pathDirTestImg, '{a}.jpg'.format(a=i))
    	cv2.imwrite(path_img_, img_)
    	list_name_test_img.append('{a}.jpg'.format(a=i))
    
  • 按张读取图片,凑够一个batch后扔给MMDetection一起测试

    list_test_img_batch = []
    list_test_img_name_batch = []
    for i in range(len(list_name_test_img)):
    	path_test_img = os.path.join(pathDirTestImg, list_name_test_img[i])
    	img = cv2.imread(path_test_img, cv2.IMREAD_COLOR)
    	list_test_img_batch.append(img)
    	list_test_img_name_batch.append(list_name_test_img[i])
    	if len(list_test_img_batch) == batchSize:
    		list_result_inf = inference_detector(model, list_test_img_batch)
    		for j in range(len(list_result_inf)):  # 单张图片测试结果
    			img_src = list_test_img_batch[j]  # 单张图片内容
    			for k in range(len(class_dic)):  # 单张图片里单一标签结果
    				scores = list_result_inf[j][k][:, -1] # 单张图片里单一标签检出的所有置信度
                   	for jj in range(len(scores)):  # 单张图片里单一标签检出的某一置信度
                    	if scores[jj] > thersMinConfidence:  # thersMinConfidence:最小报警阈值
                    		# cv2.rectangle、cv2.putText等图片处理代码
    			cv2.imwrite(os.path.join(result_path, list_test_img_name_batch[j]), list_test_img_batch[j])
    		list_test_img_batch.clear()
    		list_test_img_name_batch.clear()
    

写了很多for循环,习惯不太好,但是反正数据量也不大,就凑活着用了(躺)


这样,就基本实现了通过xml文件对部分参数进行快速设置的MMDetection2.0的训练和测试功能。

你可能感兴趣的:(深度学习,python,深度学习)