(python)xml.etree.ElemenTree 学习

最近需要用到VOC2007格式的数据集,需要自己制作xml文件。但是网上现有的程序都不能很好的运行,因此自学了一下利用python处理xml,在此记录一下。本文参考官网的文档,可以自行前往学习。

1.xml文件格式

首先,来看一下XML所包含的元素类型

1. 标签

2. 属性

3. 数据 1

例如,在下方的test.xml文件中,即为标签,分别表示的起始和终止。 1 2008 141100 4 2011 59900 68 2011 13600

现在我们尝试用python来解析上面的test.xml,建议初学者可以保存下来一起学习。

python处理xml的方式有SAX,DOM和ElementTree:

1.SAX (simple API for XML )

python 标准库包含SAX解析器,SAX用事件驱动模型,通过在解析XML的过程中触发一个个的事件并调用用户定义的回调函数来处理XML文件。

2.DOM(Document Object Model)

将XML数据在内存中解析成一个树,通过对树的操作来操作XML。

3.ElementTree(元素树)

ElementTree就像一个轻量级的DOM,具有方便友好的API。代码可用性好,速度快,消耗内存少。

在本文中主要学习ElementTree的使用,其他两种会在后续文章中给出教程。

2.ElemenTree处理xml文件

2.1解析xml文件

ElementTree相当于将xml当做树状结构来处理,每个data都相当于一个节点dot,tag存于dot.tag,att=attr则以字典的形式{att:attr}存于dot.attrib,data则以dot.text储存。详细可以查看下面代码:

import xml.etree.ElementTree as ET
tree = ET.parse('test.xml')
root = tree.getroot()
print('root.tag:%s' % root.tag)
print('root.attrib:%s' % root.attrib)

输出结果为

root.tag:data
root.attrib:{} #root标签只有标签号,没有属性

查看root下的子节点信息:

for child in root:
    print(child.tag, child.attrib)

输出结果为

country {'name': 'Liechtenstein'}
country {'name': 'Singapore'}
country {'name': 'Panama'}

到这里基本上是没什么难点的,然后再理解一下root和child之间的关系,以及读取在child中的data就能理解xml在ElementTree中的储存形式了,data的读取方式为:

print(root[0][1].text)
print(root[1][2].text)
print(root[2][3].text)

输出为

2008
59900
None

解析完毕。

注意:并非所有的XML输入元素都将作为已解析树的元素。目前,此模块会跳过输入中的所有XML注释,处理指令和文档类型声明。XML文本可以包含注释和处理指令; 它们将在生成XML输出时包含在内。

2.2 处理xml文件

对xml中的元素进行处理,例如搜索特定的元素,或者修改,删除等都在ElementTree.Element类里,详细说明可以看这里,就不一一说明了,直接贴代码,可以自行运行并看结果去理解。

for neighbor in root.iter('neighbor'):
    print(neighbor.attrib)
for neighbor in root.findall('country'): # 找到当前根节点root下的所有一级子节点
    print(neighbor.attrib)
for neighbor in root.find('country'): # 找到第一个节点‘country’下的所有子节点
    print(neighbor.attrib)

然后是修改和删除xml文件中的内容

for rank in root.iter('rank'):
    new_rank = int(rank.text) + 1
    rank.text = str(new_rank)
    rank.set('updated', 'yes')
tree.write('output.xml')

然后可以看到output.xml中如下文档


    
        2
        2008
        141100
        
        
    
    
        5
        2011
        59900
        
    

xml文件中的被修改。

2.3创建xml文件

a = ET.Element('a')
b = ET.SubElement(a, 'b')
c = ET.SubElement(a, 'c')
d = ET.SubElement(c, 'd')
ET.dump(a)

输出结果为

可以看到ElementTree使用起来还是很方便,但是有一点不好的地方,输出的结果并不是标准结构的xml文件,阅读很不方便。在这方面是需要利用xml.dom比较方便。

3.创建VOC2007格式数据集中的xml文件

from xml.dom import minidom
import os
import matplotlib.image as mlp
import random


def VOC():
    make_folder('VOCdevkit')
    make_folder('VOCdevkit/VOC2007')
    make_folder('VOCdevkit/VOC2007/Annotations/')
    make_folder('VOCdevkit/VOC2007/ImageSets')
    make_folder('VOCdevkit/VOC2007/ImageSets/Layout')
    make_folder('VOCdevkit/VOC2007/ImageSets/Main')
    make_folder('VOCdevkit/VOC2007/ImageSets/Segmentation')
    make_folder('VOCdevkit/VOC2007/JPEGImages')


def make_folder(fpath):
    if os.path.exists(fpath):
        pass
    else:
        os.mkdir(fpath)


def easy_node(dom, root, node_name, node_text):
    node = dom.createElement(node_name)
    root.appendChild(node)
    text = dom.createTextNode(str(node_text))
    node.appendChild(text)
    return dom


def load_bbox_data(data_dir):
    file_dir = os.path.join(data_dir, 'ctpn_bbox.txt')
    file_name_arr = []
    bbox_dict = {}
    label_dict = {}
    tmp = None
    with open(file_dir, 'r') as f:
        lines = f.readlines()
        for line in lines:
            arr = line.split(' ')
            bbox_tmp = [arr[2], arr[3], arr[4], arr[5][:-1]]
            if arr[0] == tmp:
                bbox_dict[arr[0]].append(bbox_tmp)
                label_dict[arr[0]].append(arr[1])
            else:
                bbox_dict[arr[0]] = [bbox_tmp]
                label_dict[arr[0]] = [arr[1]]
                file_name_arr.append(arr[0])
                tmp = arr[0]
    return bbox_dict, label_dict, file_name_arr


def make_data_set(trainval_percent, train_percent):
    xmlfilepath = 'VOCdevkit/VOC2007/Annotations'
    txtsavepath = 'VOCdevkit/VOC2007/ImageSets/Main'
    total_xml = os.listdir(xmlfilepath)

    num = len(total_xml)
    list = range(num)
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(list, tv)
    train = random.sample(trainval, tr)

    ftrainval = open(os.path.join(txtsavepath, 'trainval.txt'), 'w')
    ftest = open(os.path.join(txtsavepath, 'test.txt'), 'w')
    ftrain = open(os.path.join(txtsavepath, 'train.txt'), 'w')
    fval = open(os.path.join(txtsavepath, 'val.txt'), 'w')

    for i in list:
        name = total_xml[i][:-4] + '\n'
        if i in trainval:
            ftrainval.write(name)
            if i in train:
                ftrain.write(name)
            else:
                fval.write(name)
        else:
            ftest.write(name)
    ftrainval.close()
    ftrain.close()
    fval.close()
    ftest.close()


if __name__ == '__main__':
    label_path = 'output'
    img_path = 'VOCdevkit/VOC2007/JPEGImages'
    VOC()
    bbox, label, files = load_bbox_data(label_path)
    img_num = len(bbox)
    for file in files:
        output_file = 'VOCdevkit/VOC2007/Annotations/' + file.split('.')[0] + '.xml'
        I = mlp.imread(os.path.join(img_path, file))
        shape = I.shape
        bbox_len = len(bbox[file])
        with open(output_file, 'w') as fh:
            dom = minidom.Document()
            annotation = dom.createElement('annotation')
            dom.appendChild(annotation)
            easy_node(dom, annotation, 'folder', 'VOC2007')
            easy_node(dom, annotation, 'filename', file)
            easy_node(dom, annotation, 'object_num', bbox_len)
            size = dom.createElement('size')
            annotation.appendChild(size)
            easy_node(dom, size, 'width', shape[1])
            easy_node(dom, size, 'height', shape[0])    
            easy_node(dom, size, 'depth', shape[2])
            for i in range(bbox_len):
                objects = dom.createElement('object')
                annotation.appendChild(objects)
                easy_node(dom, objects, 'name', label[file][i])
                easy_node(dom, objects, 'difficult', 0)
                bndbox = dom.createElement('bndbox')
                objects.appendChild(bndbox)
                easy_node(dom, bndbox, 'xmin', bbox[file][i][1])
                easy_node(dom, bndbox, 'ymin', bbox[file][i][0])
                easy_node(dom, bndbox, 'xmax', bbox[file][i][3])
                easy_node(dom, bndbox, 'ymax', bbox[file][i][2])
            dom.writexml(fh, indent='', addindent='\t', newl='\n', encoding='UTF-8')
            print('写入' + output_file)
    print('xml制作完毕')
    make_data_set(trainval_percent=0.7, train_percent=0.8)
    print('数据集分割完毕')

原始数据格式如下:

00000000.jpg text 84 780 131 996
00000000.jpg text 151 586 189 858
00000000.jpg text 211 362 239 563

分别是文件名,类别,bbox的左上和右下两个坐标。

 

你可能感兴趣的:(python学习)