在训练神经网络时,经常需要使用到某些数据制作、读取的代码,这里做个总结。方便以后使用。
我制作的数据集,input为已经处理好的txt文件,label为xml文件,在保存时将它们保存为同名文件:
先在相应的位置创建两个txt文件:train.txt、test.txt。然后运行以下代码。
以下代码为将数据按照4:1的比例切割为训练集和测试集。(de=0.8设置训练集占总数据的比例)
import os
import random
xmlfilepath = "C:/radardata/11195_2020_10_30_01_04_28/savadata/xmls"
txtsavepath = "C:/radardata/11195_2020_10_30_01_04_28/savadata"
total_xml = os.listdir(xmlfilepath)
num = len(total_xml)
list = range(num)
de = 0.8
zhong = int(num * de)
trainval = random.sample(list, num)
ftrain = open(txtsavepath + '//train.txt', 'w')
ftest = open(txtsavepath + '//test.txt', 'w')
for i in list:
name = total_xml[i][:-4] + '\n'
if i < zhong:
ftrain.write(name)
else:
ftest.write(name)
ftrain.close()
ftest.close()
这样就拥有了两个txt文件:
txt文件中的格式如下所示:
读取时readline获取到每一行,分别加上后缀".txt"和".xml"即可读取。
这种数据集方式适用于语义分割。在实现样本和标签的数量和名称相同且一一对称。读取时只需要读取txt文件中每张图片的名称就可以同时获得对应的样本和标签。
先根据xml文件中的对应各个类别制作classes.txt文件。
再根据文件位置运行以下代码:
import os
import xml.etree.ElementTree as ET
#获取所有类别
def get_classes(classes_path):
classes = os.path.expanduser(classes_path)
with open(classes) as f:
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
return class_names
#读取xml文件得到所有标注框
def convert_annotation(image_id, list_file):
in_file = open(os.path.join(xml_path,'%s.xml'%(image_id)),"rb")
tree=ET.parse(in_file)
root = tree.getroot()
for obj in root.iter('object'):
cls = obj.find('name').text
cls_id = classes.index(cls) # 获取类别id
xmlbox = obj.find('bndbox')
b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
list_file.write('\n')
#获取所有类别
classes = get_classes('classes.txt')
#文件保存路径
list_file = open(r'C:\radardata\11195_2020_10_30_01_04_28\savadata\labels.txt', 'w')
#图片和xml路径
image_path = r'C:\radardata\11195_2020_10_30_01_04_28\savadata\imgs'
xml_path = r'C:\radardata\11195_2020_10_30_01_04_28\savadata\xmls'
for file in os.listdir(image_path):
image_id = os.path.splitext(file)[0] #获取文件名
list_file.write(os.path.join(image_path, file)) #写图片完整路径
convert_annotation(image_id, list_file) #写标注框
list_file.close()
label.txt文件中的数据格式为:
在调用读取时的函数为:
from random import shuffle
import numpy as np
from PIL import Image
with open(r'C:\radardata\11195_2020_10_30_01_04_28\savadata\labels.txt') as f:
lines = f.readlines()
shuffle(lines)
for annotation_line in lines:
line = annotation_line.split()
image = Image.open(line[0])
iw, ih = image.size
box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
这种数据集读取方式适用于目标检测,XML的数据保存格式为VOC的数据集格式。
数据集的制作方法参照了以下博客:
python读取数据集并生成txt文件_绿柳山庄赵公子的博客-CSDN博客