TensorFlow 制作自己数据集时,xml转csv千篇一律,把我拐入坑里了。
如果训练自己的数据集只有一个类别,用网络上的xml_to_csv,完全没有问题,源码如下:
# -*- coding: utf-8 -*-
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
def xml_to_csv(path):
xml_list = []
# 读取注释文件
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
value = (root.find('filename').text + '.jpg',
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text)
)
xml_list.append(value)
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
# 将所有数据分为样本集和验证集,一般按照3:1的比例
train_list = xml_list[0: int(len(xml_list) * 0.67)]
eval_list = xml_list[int(len(xml_list) * 0.67) + 1: ]
# 保存为CSV格式
train_df = pd.DataFrame(train_list, columns=column_name)
eval_df = pd.DataFrame(eval_list, columns=column_name)
train_df.to_csv('data/train.csv', index=None)
eval_df.to_csv('data/eval.csv', index=None)
def main():
path = './xml'
xml_to_csv(path)
print('Successfully converted xml to csv.')
main()
如果你的类别数据集,超过2类以上,再用上述源码,觉得把所有的数据集3:1的分割,而非一个类别的3:1分割 。
对上述源码略作调整,完美把每一类数据集按照9:1分割为训练数据集和测试数据集,源代码如下:
# coding: utf-8
import glob
import pandas as pd
import xml.etree.ElementTree as ET
classes = ["20Km_h", "no_passing_35", "no_passing", "keep_left", "keep_right", "mandatory", "straight_or_left", "passing_limits",
"bicycles", "pedestrians", "stop", "dangerous"]
def xml_to_csv(path):
train_list = []
eval_list = []
for cls in classes:
xml_list = []
# 读取注释文件
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
if cls == member[0].text:
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text)
)
xml_list.append(value)
for i in range(0,int(len(xml_list) * 0.9)):
train_list.append(xml_list[i])
for j in range(int(len(xml_list) * 0.9) + 1,int(len(xml_list))):
eval_list.append(xml_list[j])
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
# 保存为CSV格式
train_df = pd.DataFrame(train_list, columns=column_name)
eval_df = pd.DataFrame(eval_list, columns=column_name)
train_df.to_csv('data/train.csv', index=None)
eval_df.to_csv('data/eval.csv', index=None)
def main():
# path = 'E:\\\data\\\Images'
path = r'D:\work\PycharmPro\trafficsign\SSD_NET\data\xml_data' # path参数更具自己xml文件所在的文件夹路径修改
xml_to_csv(path)
print('Successfully converted xml to csv.')
main()
classes = ["20Km_h", "no_passing_35", "no_passing", "keep_left", "keep_right", "mandatory", "straight_or_left", "passing_limits", "bicycles", "pedestrians", "stop", "dangerous"]
该处需要改为自己数据集类别标签名。
原文:https://blog.csdn.net/miao0967020148/article/details/90208139