# modify_annotations_txt.py
import glob
import string
# Paths of all .txt label files in the current directory (run this script from inside the labels folder)
txt_list = list(glob.iglob('./*.txt'))
def show_category(txt_list):
    """Print (and return) the set of class names found in the given label files.

    Only the first whitespace-separated field of each line (the class name) is
    collected. Blank lines are ignored so they do not show up as an empty class.

    Args:
        txt_list: iterable of paths to label .txt files.

    Returns:
        The set of class names seen across all readable files. Files that cannot
        be opened are reported on stdout and skipped.
    """
    categories = set()
    for item in txt_list:
        try:
            with open(item) as tdf:
                for each_line in tdf:
                    # split() with no argument also strips and tolerates tabs / repeated spaces
                    fields = each_line.split()
                    if fields:
                        categories.add(fields[0])
        except IOError as ioerr:
            print('File error:' + str(ioerr))
    print(categories)
    return categories
def merge(line):
    """Join label fields back into one space-separated, newline-terminated line.

    Args:
        line: list of string fields, e.g. a label line after splitting.

    Returns:
        The fields joined by single spaces followed by a newline; an empty
        list yields just a newline.
    """
    # str.join is linear and needs no "no space after the last field" bookkeeping
    return ' '.join(line) + '\n'
# Class-name rewrites applied to every label line (old name -> new name).
# Add e.g. 'Truck': 'Car', 'Van': 'Car', 'Tram': 'Car' to merge the vehicle classes into Car.
CLASS_RENAMES = {'Person_sitting': 'Pedestrian'}
# Label lines whose (renamed) class is listed here are dropped entirely.
IGNORED_CLASSES = {'DontCare', 'Misc'}

if __name__ == '__main__':
    print('before modify categories are:\n')
    show_category(txt_list)
    for item in txt_list:
        new_txt = []
        try:
            with open(item, 'r') as r_tdf:
                for each_line in r_tdf:
                    labeldata = each_line.split()
                    if not labeldata:
                        continue  # drop blank lines instead of writing them back
                    labeldata[0] = CLASS_RENAMES.get(labeldata[0], labeldata[0])
                    if labeldata[0] in IGNORED_CLASSES:
                        continue
                    new_txt.append(merge(labeldata))
            # The file is rewritten in place only after it has been fully read
            with open(item, 'w') as w_tdf:
                w_tdf.writelines(new_txt)
        except IOError as ioerr:
            print('File error:' + str(ioerr))
    print('\nafter modify categories are:\n')
    show_category(txt_list)
这里忽略了 DontCare 和 Misc 类，并将 Person_sitting 类合并到了 Pedestrian 类，可以按照自己的需求进行修改。
3. 将 KITTI 的 txt 标签转换为 VOC 的 xml 格式：kitti_txt_to_xml.py
# kitti_txt_to_xml.py
# encoding:utf-8
# 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML
from xml.dom.minidom import Document
import cv2
import os
def _append_text_element(doc, parent, tag, text):
    """Create <tag>text</tag> under parent and return the new element (text goes through str())."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(element)
    return element


def generate_xml(name, split_lines, img_size, class_ind, img_ext='.jpg', out_dir='Annotations'):
    """Write a Pascal-VOC style XML annotation for one KITTI label file.

    The document is built with xml.dom.minidom and written to
    <out_dir>/<name>.xml.

    Args:
        name: file stem shared by the label file and its image.
        split_lines: raw text lines of the KITTI label file.
        img_size: image shape as (height, width, depth), e.g. cv2.imread(...).shape.
        class_ind: class names to export; objects of other classes are skipped.
        img_ext: extension recorded in <filename>. Defaults to '.jpg' as before;
            pass '.png' when the images are PNG files.
        out_dir: output directory; created if it does not exist.

    Raises:
        IndexError: if an exported label line has fewer than 8 fields.
        ValueError: if a bounding-box field is not numeric.
    """
    doc = Document()
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _append_text_element(doc, annotation, 'folder', 'KITTI')
    _append_text_element(doc, annotation, 'filename', name + img_ext)

    source = doc.createElement('source')
    annotation.appendChild(source)
    _append_text_element(doc, source, 'database', 'The KITTI Database')
    _append_text_element(doc, source, 'annotation', 'KITTI')

    size = doc.createElement('size')
    annotation.appendChild(size)
    # img_size is indexed as (height, width, depth)
    _append_text_element(doc, size, 'width', img_size[1])
    _append_text_element(doc, size, 'height', img_size[0])
    _append_text_element(doc, size, 'depth', img_size[2])

    for split_line in split_lines:
        line = split_line.split()
        if not line or line[0] not in class_ind:
            continue  # blank line, or a class that is not exported
        obj = doc.createElement('object')  # not named `object`, which would shadow the builtin
        annotation.appendChild(obj)
        _append_text_element(doc, obj, 'name', line[0])
        bndbox = doc.createElement('bndbox')
        obj.appendChild(bndbox)
        # Fields 4..7 of a label line hold the box as xmin, ymin, xmax, ymax;
        # the float values are truncated to integers for the XML.
        for tag, index in (('xmin', 4), ('ymin', 5), ('xmax', 6), ('ymax', 7)):
            _append_text_element(doc, bndbox, tag, int(float(line[index])))

    os.makedirs(out_dir, exist_ok=True)
    # `with` closes the file even if writing fails
    with open(os.path.join(out_dir, name + '.xml'), 'w') as f:
        f.write(doc.toprettyxml(indent=''))
if __name__ == '__main__':
    class_ind = ('Car', 'Cyclist', 'Truck', 'Van', 'Pedestrian', 'Tram')
    cur_dir = os.getcwd()
    labels_dir = os.path.join(cur_dir, 'label_2')
    # Image folder -- adjust to your layout. NOTE(review): "trian" is kept from the
    # original path; check whether your folder is actually named "train".
    img_dir = './JPEGImages/trian'
    # os.walk yields (directory, subdirectories, files) for labels_dir and every subdirectory
    for parent, _dirnames, filenames in os.walk(labels_dir):
        for file_name in filenames:
            name, ext = os.path.splitext(file_name)
            if ext.lower() != '.txt':
                continue  # only label files are converted
            full_path = os.path.join(parent, file_name)
            with open(full_path) as f:  # closed right away instead of leaking one handle per file
                split_lines = f.readlines()
            img_path = os.path.join(img_dir, name + '.png')
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals a missing/unreadable file by returning None
                raise FileNotFoundError('cannot read image: ' + img_path)
            generate_xml(name, split_lines, img.shape, class_ind)
    print('all txts has converted into xmls')
# xml_to_yolo_txt.py
# 此代码和VOC_KITTI文件夹同目录
import glob
import xml.etree.ElementTree as ET
# Class names exactly as they appear in the XML <name> tags. A class's position in
# this list becomes its YOLO class index, so the order matters and should match the
# `names` list of the dataset YAML.
class_names = ['Car', 'Cyclist', 'Truck', 'Van', 'Pedestrian', 'Tram']
# Directory containing the XML annotation files
path = './Annotations/'


# Convert a single XML file into a YOLO txt label file
def single_xml_to_txt(xml_file):
    """Convert one VOC-style XML annotation into a YOLO label file.

    The txt file is written next to the XML with the same stem. Each object
    becomes one line "class_index x_center y_center width height", where the
    four box values are divided by the image width/height from <size>. Every
    line is also echoed to stdout.

    Child elements are looked up by tag name rather than position, so VOC files
    that carry extra <pose>/<truncated>/<difficult> tags are handled as well.

    Raises:
        ValueError: if an object's class is not listed in class_names.
    """
    import os  # local import: this script's header only imports glob and ElementTree

    root = ET.parse(xml_file).getroot()
    size = root.find('size')
    picture_width = float(size.findtext('width'))
    picture_height = float(size.findtext('height'))
    # splitext works for any directory name, unlike splitting the whole path on '.'
    txt_path = os.path.splitext(xml_file)[0] + '.txt'
    with open(txt_path, 'w') as txt_file:
        for member in root.findall('object'):
            class_num = class_names.index(member.findtext('name'))
            bndbox = member.find('bndbox')
            # float() also accepts fractional coordinates written by some annotation tools
            box_x_min = float(bndbox.findtext('xmin'))  # top-left x
            box_y_min = float(bndbox.findtext('ymin'))  # top-left y
            box_x_max = float(bndbox.findtext('xmax'))  # bottom-right x
            box_y_max = float(bndbox.findtext('ymax'))  # bottom-right y
            # Convert corners to a normalised centre point plus width/height
            x_center = (box_x_min + box_x_max) / (2 * picture_width)
            y_center = (box_y_min + box_y_max) / (2 * picture_height)
            width = (box_x_max - box_x_min) / picture_width
            height = (box_y_max - box_y_min) / picture_height
            print(class_num, x_center, y_center, width, height)
            txt_file.write(' '.join(str(v) for v in (class_num, x_center, y_center, width, height)) + '\n')
# Convert every XML file in a directory into YOLO txt files
def dir_xml_to_txt(path):
    """Run single_xml_to_txt on each *.xml file directly inside *path*.

    Args:
        path: directory to scan; works with or without a trailing separator.
    """
    import os  # local import: os is not imported in this script's header

    # sorted() gives a stable processing order (glob order is arbitrary)
    for xml_file in sorted(glob.glob(os.path.join(path, '*.xml'))):
        single_xml_to_txt(xml_file)


if __name__ == '__main__':
    dir_xml_to_txt(path)
import os
import random
pngfilepath = r'./KITTI_Img'  # folder containing the .png images
saveBasePath = r"./"  # where train.txt / val.txt / test.txt are written

train_percent = 0.7
val_percent = 0.15
test_percent = 0.15  # informational: the test split is simply everything left after train and val
random_seed = None  # set to an int for a reproducible split


def split_dataset(items, train_percent, val_percent, seed=None):
    """Randomly partition items into disjoint train / val / test lists.

    Args:
        items: sequence of entries (here: image file names) to split.
        train_percent: fraction of items for training (count truncated with int()).
        val_percent: fraction of items for validation (count truncated with int()).
        seed: optional seed for a reproducible split; None uses system randomness.

    Returns:
        A (train, val, test) tuple of lists that together contain every item
        exactly once; test receives everything not chosen for train or val.
    """
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    num_train = int(len(shuffled) * train_percent)
    num_val = int(len(shuffled) * val_percent)
    # Slicing one shuffled list keeps the splits disjoint without O(n^2) list.remove calls
    train = shuffled[:num_train]
    val = shuffled[num_train:num_train + num_val]
    test = shuffled[num_train + num_val:]
    return train, val, test


def _write_list(filename, names):
    """Write one image path (pngfilepath joined with the name) per line to saveBasePath/filename."""
    with open(os.path.join(saveBasePath, filename), 'w') as f:  # `with` guarantees every list file is closed
        for name in names:
            f.write(os.path.join(pngfilepath, name) + '\n')


if __name__ == '__main__':
    # Only .png files are split; sorting makes a seeded split independent of listdir order
    total_png = sorted(png for png in os.listdir(pngfilepath) if png.endswith('.png'))
    train, val, test = split_dataset(total_png, train_percent, val_percent, random_seed)
    _write_list('train.txt', train)
    _write_list('val.txt', val)
    _write_list('test.txt', test)
至此数据集处理完毕。
数据集需要满足以下目录结构，且 images 和 labels 必须位于同一级目录下！
kitti
├── images
├── labels
├── train.txt
├── val.txt
└── test.txt
# Create an isolated environment. The ultralytics/yolov5 README states a Python>=3.8
# requirement, so python=3.7 may fail to install requirements.txt.
conda create -n yolov5 python=3.8
conda activate yolov5
git clone https://github.com/ultralytics/yolov5
cd yolov5
pip install -r requirements.txt
# Dataset config for YOLOv5 (data/kitti.yaml).
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: 你的数据集地址/kitti/ # dataset root dir (replace the "your dataset path" placeholder with the real location of the kitti folder)
train: train.txt # list of training images (relative to 'path')
val: val.txt # list of validation images (relative to 'path')
test: test.txt # list of test images (relative to 'path')
# Classes
nc: 6 # number of classes; should equal the number of entries in names
names: ['Car', 'Cyclist', 'Truck', 'Van', 'Pedestrian', 'Tram'] # class names; list position = class index used in the YOLO label txt files
# NOTE(review): this appears to be the edit to the model config (e.g. models/yolov5s.yaml) so that its
# class count matches the 6 names in data/kitti.yaml -- confirm which file this line belongs to.
nc: 6
# Argument defaults to edit in YOLOv5's train.py (presumably its argparse setup -- confirm against your checkout):
# initial weights, model config, and the KITTI dataset config data/kitti.yaml.
parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default=ROOT / 'models/yolov5s.yaml', help='model.yaml path')
parser.add_argument('--data', type=str, default=ROOT / 'data/kitti.yaml', help='dataset.yaml path')
5. 训练
# Pass the weights, model config and dataset config explicitly so training uses the KITTI setup even if train.py's defaults were not edited
python train.py --weights yolov5s.pt --cfg models/yolov5s.yaml --data data/kitti.yaml
6. 模型评估
# Evaluate the trained model on the KITTI dataset config; adjust the weights path to where train.py saved best.pt for your run
python val.py --data data/kitti.yaml --weights runs/train/exp/weights/best.pt
以下是训练 100 个 epoch 的结果：
7. 训练结果:
result:
val_batch0_pred.jpg
val_batch1_pred.jpg
val_batch2_pred.jpg