TensorFlow 对数据集标记的xml文件解析记录

环境

  • Windows:10
  • Python 3.7.10
  • TensorFlow:2.3
  • matplotlib:3.3.4
  • lxml:4.7.1

最近要用TensorFlow做20种水果识别,对刚入手的数据集,开始对数据集进行检验。

原图如下:
TensorFlow 对数据集标记的xml文件解析记录_第1张图片
以下是通过精灵标注助手生成的xml 文件


<annotation>
<folder>菠萝folder>
<filename>pineapple.jpgfilename>
<path>C:\Users\Desktop\pineapple.jpgpath>
<source>
    <database>Unknowndatabase>
source>
<size>
    <width>730width>
    <height>413height>
    <depth>3depth>
size>

<segmented>0segmented>
    <object>
    <name>菠萝name>
    <pose>Unspecifiedpose>
    <truncated>0truncated>
    <difficult>0difficult>
    <bndbox>
        <xmin>125xmin>
        <ymin>112ymin>
        <xmax>543xmax>
        <ymax>400ymax>
    bndbox>
object>
    <object>
    <name>菠萝name>
    <pose>Unspecifiedpose>
    <truncated>0truncated>
    <difficult>0difficult>
    <bndbox>
        <xmin>547xmin>
        <ymin>97ymin>
        <xmax>721xmax>
        <ymax>390ymax>
    bndbox>
object>
annotation>

安装 matplotlib

pip install matplotlib

安装 lxml

pip install lxml 

通过以下代码将xml中绘画的矩形框显示到图片中。

import tensorflow as tf
import matplotlib.pyplot as plt
from lxml import etree
from matplotlib.patches import Rectangle  # 绘制矩形框

img = tf.io.read_file(r'./pineapple.jpg')

img = tf.image.decode_jpeg(img)  # 对图像进行解码
print(img.shape)
plt.imshow(img)
plt.show()

xml = open(r'./pineapple.xml', encoding='utf-8').read()  # 读取 xml文件
sel = etree.HTML(xml)  # 对 xml 文件进行解析
width = sel.xpath('//size/width/text()')[0]  # 获取图片的宽
height = sel.xpath('//size/height/text()')[0]  # 获取图片的高
bndbox = sel.xpath('//bndbox')
ax = plt.gca()  # 获取当前图像
for i in range(0, len(bndbox)):
    xmin = sel.xpath('//bndbox/xmin/text()')[i]
    ymin = sel.xpath('//bndbox/ymin/text()')[i]
    xmax = sel.xpath('//bndbox/xmax/text()')[i]
    ymax = sel.xpath('//bndbox/ymax/text()')[i]
    xmin = int(xmin)
    ymin = int(ymin)
    xmax = int(xmax)
    ymax = int(ymax)
    plt.imshow(img.numpy())
    rect = Rectangle((xmin, ymin), (xmax - xmin), (ymax - ymin), fill=False, color='red')  # fill=False 不需要填充
    ax.axes.add_patch(rect)  # 添加矩形框
plt.show()

还原出入手的数据集用精灵标注助手标记的效果如下:
TensorFlow 对数据集标记的xml文件解析记录_第2张图片

由于发现数据集中有多边形和矩形框数据混合,所以通过以下代码区分开来
TensorFlow 对数据集标记的xml文件解析记录_第3张图片
以上xml文件一个一个的点开查看比较麻烦,用以下代码进行处理查看:

import os

try:
    import xml.etree.cElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET

# xml文件路径
txt_path = 'C:\\Users\\vvcat\\Desktop\\xxx\\xxxxx\\outputs2\\'

for txt_file in os.listdir(txt_path):
    txt_name = os.path.splitext(txt_file)[0]  #获取文件名
    txt_suffix = os.path.splitext(txt_file)[1]  # 获取后缀
    # print(txt_name, txt_suffix)
    file_name_path = txt_path + txt_name + txt_suffix
    root = ET.parse(file_name_path)
    bndboxs = root.getiterator("bndbox")

    if bndboxs == []:
        print(txt_name + txt_suffix)   # 打印包含多边形框的xml文件

效果如下:
TensorFlow 对数据集标记的xml文件解析记录_第4张图片
打开A(1).xml文件,内容如下:
TensorFlow 对数据集标记的xml文件解析记录_第5张图片
通过以下代码批量将xml中绘画的矩形框显示到图片中,并保存成新的图片。

import tensorflow as tf
import matplotlib.pyplot as plt
from lxml import etree
from matplotlib.patches import Rectangle  # 绘制矩形框
import glob
import os

images = glob.glob('./inputs/*.jpg')
xmls = glob.glob('./outputs/*.xml')
xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in xmls]
images_names = [x.split('\\')[-1].split('.jpg')[0] for x in images]
names = list(set(images_names) & set(xmls_names))
imgs = [img for img in images if img.split('\\')[-1].split('.jpg')[0] in names]  #根据名称排序
imgs.sort(key=lambda x: x.split('\\')[-1].split('.jpg')[0])
xmls.sort(key=lambda x: x.split('\\')[-1].split('.xml')[0])

dstfile = './output_image/'
fpath = os.path.dirname(dstfile)  # 获取文件路径
if not os.path.exists(fpath):
    os.makedirs(fpath)  # 没有就创建路径
images_names = ''

for i in range(0, len(xmls)):
    img = tf.io.read_file(imgs[i])
    img = tf.image.decode_jpeg(img)  # 对图像进行解码
    xml = open(xmls[i], encoding='utf-8').read()  # 读取 xml文件
    sel = etree.HTML(xml)  # 对 xml 文件进行解析
    width = sel.xpath('//size/width/text()')[0]  # 获取图片的宽
    height = sel.xpath('//size/height/text()')[0]  # 获取图片的高
    bndbox = sel.xpath('//bndbox')
    ax = plt.gca()  # 获取当前图像
    for j in range(0, len(bndbox)):
        xmin = sel.xpath('//bndbox/xmin/text()')[j]
        ymin = sel.xpath('//bndbox/ymin/text()')[j]
        xmax = sel.xpath('//bndbox/xmax/text()')[j]
        ymax = sel.xpath('//bndbox/ymax/text()')[j]
        xmin = int(xmin)
        ymin = int(ymin)
        xmax = int(xmax)
        ymax = int(ymax)
        plt.imshow(img.numpy())
        rect = Rectangle((xmin, ymin), (xmax - xmin), (ymax - ymin), fill=False, color='red')  # fill=False 不需要填充
        ax.axes.add_patch(rect)  # 添加矩形框
    images_names = imgs[i].split('\\')[-1]
    plt.savefig(dstfile + images_names)
    # plt.show()
    plt.close()

你可能感兴趣的:(深度学习,tensorflow,xml,python)