原先使用MMDetection1.0的时候,为了避免每次训练新模型时都要修改各种配置文件的麻烦,同事写了一个config.xml
文件,并修改了MMDetection的各个文件,使之能读取该配置xml中的相应参数。这样在训练时,如果要修改标签、数据集大小、使用模型、训练次数等参数,可以通过修改xml文件中的相应参数来实现,而不需要进入到MMDetection的各个配置文件里面手动修改,增加了模型训练和测试的便捷性,节省了研发时间。
最近部门统一改用MMDetection2.11,发现与MMDetection1.0略微存在差别,深层次的差别我没有进行探究,最明显的是2.0版本对模型配置文件进行了拆分,但是每个拆分后的文件都还比较眼熟,算是熟面孔。这里参考同事写MMDetection1.0配置xml文件和修改MMDetection配置文件的思路,对MMDetection2.11的配置文件进行修改,使之能实现上述功能。
首先,在config.xml
文件里对所有在训练时会经常修改的参数进行预定义,具体实现方法如下:
将MMDetection2.11解压在mmdetection-master-v211
文件夹中,并在此文件夹下进行编译,安装成功后,创建工程。文件夹名可以任意定义,但是在后面配置文件的编写中,该名称将作为锚点使用,用于不同文件之间互相查找并确定位置,因此一旦确定后,就不应当再进行修改。
在该文件夹下创建program
文件夹,用于存放config.xml
文件和运行程序。
config.xml
文件内容如下:
faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
256,128
suokou,guzhangac
4
12
0.95
Faster-RCNN
/home/work2/023_X05_P01/400AF/fortest/
参数说明:
--module_type:定义训练和测试时使用的模型,示例中所示的路径在configs文件夹中。
--image_size:定义数据集中的图片在训练和测试时resize的大小(w, h),配置文件中不需要加括号。
--traindata_lable:标签名,在写配置xml时可以暂不定义,后面有程序可以读取数据集中的xml文件,确定数据集中所有的标签,并自动写入该xml中。如果对标签名的顺序有要求,则可以自定义。
--train_batch:定义训练时的batchsize。
--train_epochs:定义训练时的epoch。
--train_percent:定义划分数据集时数据集与验证集比例。
--save_path_name:定义训练结果的保存路径文件夹名,该文件夹将被创建在tools文件夹下。
--predict_imagePath:定义测试文件夹路径,如定义为0,则自动测试验证集。
注:这里只给出了我平时经常会修改的参数,如果还有其他参数修改频率较高,可在config.xml
文件中继续添加参数,并在程序中进行相应修改。
在配置好config.xml
文件后,为了能够读取配置xml中的参数,在自定义程序和修改后的MMDetection配置文件中对相应参数进行设置,需要使用函数实现xml文件的读取和参数输出。该函数定义如下:
def config_parse(path, input_list):
tree = ET.parse(path)
root = tree.getroot()
output = []
for input_i in input_list:
output.append(root.find(input_i).text)
return output
函数功能:给定想要从config.xml
文件中读取的参数标签名,输出查找结果。输入参数以列表形式给出。
如:[pathDirMmdet, save_path_name] = config_parse(path_config, ['pathDirMmdet','save_path_name'])
在program
文件夹下创建edit_label.py
程序,由于与config.xml
文件在同文件夹下,所以不需确定配置文件位置,可直接读取。
程序功能:自动读取数据集中的图片xml,确定所有标签,并写入config.xml
中traindata_lable参数。
除去前面的config_parse()
函数外,该文件还包括两个主要函数:get_labels()
函数,用于读取图片xml并确定标签;和set_labels()
函数,用于向config.xml
文件写入标签名。
代码分别如下:
def get_labels(xml_path):
xml_path_list = os.listdir(xml_path)
labelList = []
for xml_i in xml_path_list:
if 'xml' in xml_i:
file_path = os.path.join(xml_path, xml_i)
tree = ET.parse(file_path)
root = tree.getroot()
for obj in root.findall('object'):
name = obj.find('name').text
if name not in labelList:
labelList.append(name)
return labelList
def set_labels(path, labelList):
tree = ET.parse(path)
root = tree.getroot()
root.find('traindata_lable').text = (',').join(labelList)
tree.write(path)
print('设置标签为:', (',').join(labelList))
由此,完成了训练和测试标签的自定义。但是此时MMDetection程序还不能读取自定义标签,仍在使用coco或voc数据集中的默认标签,因此还需要修改MMDetection中的一些数据集配置文件,使其标签与自定义标签相一致。
要修改的文件主要包括:
~/mmdet/datasets/voc.py
~/mmdet/datasets/coco.py
~/mmdet/core/evaluation/class_names.py
其中,~
表示主文件夹目录。
由于这几个文件与config.xml
文件不在同级目录,所以读取文件内容时,要先确定config.xml
文件位置,使用:
path_config = os.path.join(sys.argv[0].split('mmdetection-master-v211')[0], 'mmdetection-master-v211', 'programs', 'config.xml')
语句进行查找定位。
以列表形式读到标签内容:
[label_str] = config_parse(path_config, ['traindata_lable'])
label_list = label_str.split(',')
对于class_names.py
文件,修改后的voc_classes()
如下:
def voc_classes():
print('\n【class_names.py -> voc_classes()】', '\npathPyClassVoc:', pathPyClassVoc)
[label_str] = config_parse(path_config, ['traindata_lable'])
label_list = label_str.split(',')
print('voc list_class:', label_list)
print('len(list_class):', len(label_list))
return label_list
下面的coco_classes()
同理。
在class_names.py
中,label_list
本身的类别就是列表,因此无需进行进一步的修改,而在voc.py
与coco.py
中,还需进一步修改。
由于MMDetection中标签名是以元组形式存储的,所以要使用语句CLASSES = tuple(label_list)
将标签列表转化为元组,并且注释掉原本对CLASSES
的定义。
特别注意的是:之前我参考的教程里说,如果只有一个标签,比如数据集中的标签只有“suokou”,则在写CLASSES
的时候,要写为CLASSES=('suokou',)
,即在唯一的标签名后加一个逗号,不然会报错。我尝试了一下,发现确实如此,刚开始不知道原因,后来无意中看到在Python中,如果定义一个元组,二元组中只有一个元素时,元组tuple会被转化为字符串str,而如果在这唯一的元素后增加一个逗号,则可以避免这一情况。具体如下:
>>>a = ('x')
>>> type(a)
<class 'str'>
>>> b = ('x',)
>>> type(b)
<class 'tuple'>
而在MMDetection中,CLASSES
的类型必须是元组,不然则会报错。因此前面所说,单一标签时额外添加的逗号是这个原因。而使用CLASSES = tuple(label_list)
转化后的标签,及时只有单一元素,转化后的也是元组,不会变为字符串,不需要手动在末尾添加逗号。
修改后的class VOCDataset()
如下:
class VOCDataset(XMLDataset):
path_config = os.path.join(sys.argv[0].split('mmdetection-master-v211')[0]
, 'mmdetection-master-v211','programs', 'config.xml')
[label_str] = config_parse(path_config, ['traindata_lable'])
label_list = label_str.split(',')
CLASSES = tuple(label_list)
coco.py
中的CocoDataset()
同理。
在program
文件夹下创建preprocess_data.py
程序。
程序功能:按config.xml
文件中train_percent
参数设定的比例划分训练集与验证集;将voc数据集转为coco数据集,将instances_train2017.json
、instances_val2017.json
、instances_test2017.json
和训练集、验证集图片保存在config.xml
文件中save_path_name
参数设定的保存路径下相应的coco文件夹中。
这部分代码网上参考很多,可以直接拿来用,不做赘述。
以faster_rcnn_r50_fpn_1x_coco.py
为例,打开该文件,发现比起MMDetection1.0, 2.0版本将该模型拆分成四个部分:
_base_ = ['../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
对这四个文件以此进行修改。
该文件中,对网络结构进行了定义,主要需要修改的变量为标签类别数numClass
。
获得numClass
的方法:读取config.xml
文件的标签,求标签列表的长度。值得注意的是,MMDetection2.0中,背景不再单独作为一类,所以numClass
不需在标签数的基础上+1,有几个标签,numClass
就写多少。
其他的一些参数有时也会修改,比如backbone
的type
、rpn_head
下anchor_generator
的scales
和ratios
等,但我的修改频率不高,所以没有单拎出来作为可以在config.xml
中直接修改的变量,如果遇到需要修改这些参数的情况,需要进到faster_rcnn_r50_fpn.py
文件中进行手动修改。
该文件主要定义训练和测试集路径、图像缩放尺寸、训练批大小等。
数据存放的主路径为config.xml
文件中save_path_name
参数定义的路径,数据均按COCO格式存放,不需额外修改。
图像缩放尺寸img_scale
参数、训练批大小batchSize
均由config.xml
定义。
其中,MMDetection要求img_scale
为元组格式,因此读入config.xml
里自定义的trainimage_size
参数后,需使用scaleImgResize = tuple([int(i) for i in imgSize.split(',')])
语句对其进行转化。
batchSize
的定义很简单,只需读入config.xml
的train_batch
参数,再使用int()
将其转化为整数即可。
该文件主要定义学习策略,修改初始学习率、训练次数、优化器类型等。
其中,训练次数通过config.xml
文件进行设置,其他参数暂时不动,使用默认。
该文件主要定义预训练权重,是否从断点继续训练模型等,不需要额外修改。
顺带一提,该文件中resume_from
与load_from
的区别在于:
resume_from
不仅要从checkpoint文件中读取权重也需要得到特定的优化器状态和epoch数目, 用于程序运行过程中中断后继续训练;
而load_from
仅用于加载模型并微调。
在program
文件夹下创建my_train.py
程序。
程序的主要功能是调用系统自带的train.py
训练程序,调用时输入自定义训练参数,包括使用的模型和权重保存路径等,上述参数均由config.xml
文件获得。
运行程序,发现程序报错。由于报错涉及到MMDetection2.0的代码,与my_train.py
程序本身无关,所以放在下一部分进行解决。
在进行上述修改之后,直接运行my_train.py
程序,会发现报如下错误:
Traceback (most recent call last):
File "/home/work1/mmdetection-master-v211/tools/train.py", line 187, in <module>
main()
File "/home/work1/mmdetection-master-v211/tools/train.py", line 89, in main
cfg = Config.fromfile(args.config)
File "/home/work1/mmdetection-master-v211/mmcv/utils/config.py", line 257, in fromfile
use_predefined_variables)
File "/home/work1/mmdetection-master-v211/mmcv/utils/config.py", line 183, in _file2dict
raise KeyError('Duplicate key is not allowed among bases')
KeyError: 'Duplicate key is not allowed among bases'
这时就需要通过debug定位报错位置,在train.py
的cfg = Config.fromfile(args.config)
这行添加断点,进入到~/mmcv/utils/config.py
程序中,再进入cfg_dict, cfg_text = Config._file2dict(filename,use_predefined_variables)
函数,发现报错原因是函数的这部分语句:
for c in cfg_dict_list:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)
检查cfg_dict_list
的内容,发现该变量为长度为4的列表,再检查列表中的内容,发现列表中每一元素均为长度不一的字典,字典的key
比较眼熟,包括'ET'
、'os'
、'sys'
等为了读取config.xml
而自己导入的包;'imgSize'
、'save_path_name'
、'train_epoch'
等在读取config.xml
时定义的中间变量;和'numClass'
、'batchSize'
、'maxEpoch'
等各个模型配置文件中官方定义的模型参数。
由此可以确定,列表中的四个元素分别对应faster_rcnn_r50_fpn_1x_coco.py
文件下_base_
中的四个模型配置文件,即以下四个文件。
_base_ = ['../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
而通过单步运行程序可以发现,程序在通过for c in cfg_dict_list
这行语句遍历列表中的每个元素,即遍历上述四个配置文件中引入的参数时,有部分参数同时出现在了多个配置文件中,导致列表中存储的四个字典中存在重复的key
,触发了报警程序。而正常情况下,由于这四个配置文件设置了不同的参数,彼此之间没有重复,所以不会出现参数重复的情况,因此也就不会报警。
重复的参数包括:
'ET'
、'os'
、'sys'
,在四个配置文件中均import了这三个包config.xml
时的中间变量:path_config
等这时我有了一个大胆的想法:反正多出点变量影响又不大,又没碍着谁,我把这行报警注释掉试试呢?
(先改了再说,只要程序能跑起来,哪管身后洪水滔天(不是))
于是我注释掉了这两行程序:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError('Duplicate key is not allowed among bases')
很好,确实不报这个错了呢!(要是还报就怪了)
然而这时报了新的错误——
……
File "/home/work1/mmdetection-master-v211/mmcv/utils/config.py", line 418, in pretty_text
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
……
ParseCodeToTree
ast.parse(code)
File "/home/ty/anaconda3/envs/mmdet2/lib/python3.7/ast.py", line 35, in parse
return compile(source, filename, mode, PyCF_ONLY_AST)
File "" , line 1
ET=<module 'xml.etree.ElementTree' from '/home/ty/anaconda3/envs/mmdet2/lib/python3.7/xml/etree/ElementTree.py'>
^
SyntaxError: invalid syntax
看起来还是自己导入的包惹的祸。
但是这次的报错就乱七八糟地报了一堆,进入到各种乱七八糟的程序里,乱七八糟的也看不明白。
考虑到报警里涉及到我自己导入的包,报警原因肯定还是我对模型配置文件的修改,所以我分析问题还是出现在程序在运行时,我导入的包和自定义的参数对后面程序的某些运行过程造成了一些影响。所以还是同样在debug中单步执行config.py
程序,又进入到_file2dict
函数中,程序顺利运行过之后,检查base_cfg_dict
这个变量,通过与修改前的变量进行比对,这样就发现了问题:
我的base_cfg_dict
比程序未修改时的base_cfg_dict
多了key
,和前面重复的变量一样,多的内容主要就是我import的包:'ET'
、'os'
、'sys'
,和在四个配置文件中读取config.xml
时的各种中间变量。
这个也很好理解,因为我只是单纯注释掉了重复变量,而这些我自定义的内容,全都通过base_cfg_dict.update(c)
这行语句加入到base_cfg_dict
字典中,很显然这个字典中的内容在后面的程序中被用到了,程序读取到了我自定义的变量,又没有相应的处理手段,因此报错。
所以我又有了一个大胆的想法:我把这些我自定义的变量都从这个字典里简单直接粗暴地删掉,让字典恢复如初呢?
于是我仔细比对了修改前后的base_cfg_dict
内容,对自己新加入的键值十分暴力地进行删除:
del base_cfg_dict['ET'], base_cfg_dict['os'], base_cfg_dict['sys']\
, base_cfg_dict['config_parse'], base_cfg_dict['path_config']\
, base_cfg_dict['label_str'], base_cfg_dict['label_list'], base_cfg_dict['save_path_name']\
, base_cfg_dict['img_size'], base_cfg_dict['train_batch'], base_cfg_dict['train_epochs']
“啪”的一下,很快啊!问题解决了!
程序跑起来了!
简直喜闻乐见!大快人心!普天同庆!奔走相告!
这一刻,我的脑海中不禁出现了这张图:
然鹅不行,我还得搞个程序来实现测试。
在program
文件夹下创建infer_test.py
程序。
实现功能:设置测试时的batchsize
,根据config.xml
中设置的测试路径,按批测试图片;并将测试结果(包括图片和标签、置信度)保存在~/result
文件夹下。
按批测试功能在TensorFlow相关文章中有很多代码可以参考,毕竟Google API很久之前就支持按批测试了,这里采取的思路是:
计算图片总数除以batchsize
的余数,即为需要填充的图片数,生成相应的纯黑图片,保存在文件夹中。
list_name_test_img = os.listdir(pathDirTestImg) # pathDirTestImg为测试路径,默认路径下全为图片
num_add = batchSize - len(list_name_test_img) % batchSize
for i in range(num_add):
img_ = np.zeros(img_size, dtype=np.uint8)
path_img_ = os.path.join(pathDirTestImg, '{a}.jpg'.format(a=i))
cv2.imwrite(path_img_, img_)
list_name_test_img.append('{a}.jpg'.format(a=i))
按张读取图片,凑够一个batch后扔给MMDetection一起测试
list_test_img_batch = []
list_test_img_name_batch = []
for i in range(len(list_name_test_img)):
path_test_img = os.path.join(pathDirTestImg, list_name_test_img[i])
img = cv2.imread(path_test_img, cv2.IMREAD_COLOR)
list_test_img_batch.append(img)
list_test_img_name_batch.append(list_name_test_img[i])
if len(list_test_img_batch) == batchSize:
list_result_inf = inference_detector(model, list_test_img_batch)
for j in range(len(list_result_inf)): # 单张图片测试结果
img_src = list_test_img_batch[j] # 单张图片内容
for k in range(len(class_dic)): # 单张图片里单一标签结果
scores = list_result_inf[j][k][:, -1] # 单张图片里单一标签检出的所有置信度
for jj in range(len(scores)): # 单张图片里单一标签检出的某一置信度
if scores[jj] > thersMinConfidence: # thersMinConfidence:最小报警阈值
# cv2.rectangle、cv2.putText等图片处理代码
cv2.imwrite(os.path.join(result_path, list_test_img_name_batch[j]), list_test_img_batch[j])
list_test_img_batch.clear()
list_test_img_name_batch.clear()
写了很多for循环,习惯不太好,但是反正数据量也不大,就凑活着用了(躺)
这样,就基本实现了通过xml文件对部分参数进行快速设置的MMDetection2.0的训练和测试功能。