视觉数据标注需要的代码,批量重命名文件夹中的图片文件, 图像保存格式转换, 去除相似度高的图片

1. 批量重命名文件夹中的图片文件

# -*- coding:utf8 -*-

import os


class BatchRename():
    """Rename every ``.jpg`` file in a folder to a sequential numeric name.

    Files are visited in sorted name order and renamed in place to
    ``'00' + <index> + '.jpg'``, where the index is left-padded with zeros to
    at least two characters (e.g. ``003172.jpg``, ``003173.jpg``, ...).

    NOTE: ``os.rename`` may replace an existing destination file on POSIX
    systems, so make sure the target names do not collide with files that
    must be kept.
    """

    # Defaults kept from the original script so ``BatchRename()`` still works.
    DEFAULT_PATH = '/home/zhy/Documents/Perception/camera_data/mine_obstacle_image/image/6/'
    DEFAULT_START_INDEX = 3172

    def __init__(self, path=DEFAULT_PATH, start_index=DEFAULT_START_INDEX):
        """Store the folder to process and the number given to the first file.

        Args:
            path: Folder whose ``.jpg`` files are renamed.
            start_index: Number used for the first renamed file.
        """
        self.path = path
        self.start_index = start_index

    def rename(self):
        """Rename the ``.jpg`` files and return how many were renamed.

        Entries that do not end in ``.jpg`` are left untouched. A file whose
        rename fails is reported and skipped; its number is reused for the
        next file.

        Returns:
            The number of files successfully renamed.
        """
        folder = os.path.abspath(self.path)
        filelist = sorted(os.listdir(folder))
        total_num = len(filelist)  # every directory entry, not only .jpg files
        index = self.start_index
        converted = 0
        for item in filelist:
            if not item.endswith('.jpg'):
                continue
            src = os.path.join(folder, item)
            dst = os.path.join(folder, '00' + format(str(index), '0>2s') + '.jpg')
            try:
                os.rename(src, dst)
            except OSError as err:
                # Report the failure instead of silently swallowing it.
                print('failed to rename %s: %s' % (src, err))
                continue
            print('converting %s to %s ...' % (src, dst))
            index += 1
            converted += 1
        # Report the real number of renamed files (not the next free index).
        print('total %d to rename & converted %d jpg' % (total_num, converted))
        return converted


if __name__ == '__main__':
    # Script entry point: run the batch rename with the class defaults.
    BatchRename().rename()

2. 图像保存格式转换

# -*- coding:utf8 -*-
import os
import cv2 as cv

image_path = '/home/zhy/Documents/Perception/标注/data1/images/val'  # directory the source images are read from
save_path = '/home/zhy/Documents/Perception/标注/data1/images/1'  # directory the converted .jpg files are written to

# Create the output directory when it does not exist yet.
if not os.path.exists(save_path):
    os.makedirs(save_path)

image_file = os.listdir(image_path)

# Extensions (lower-case, without the dot) that are converted; matching is
# case-insensitive so e.g. "JPG", "Png" and "JPEG" are all accepted.
supported_extensions = ('bmp', 'jpg', 'jpeg', 'png')

for image in image_file:
    # e.g. "112345.jpeg" -> ("112345", ".jpeg")
    stem, extension = os.path.splitext(image)
    if extension[1:].lower() not in supported_extensions:
        continue
    src = cv.imread(os.path.join(image_path, image))
    if src is None:
        # cv.imread returns None (it does not raise) when a file cannot be
        # decoded; passing that to imwrite would fail.
        print('skipping unreadable image: %s' % image)
        continue
    output_img_name = stem + ".jpg"
    if not cv.imwrite(os.path.join(save_path, output_img_name), src):
        print('failed to write: %s' % output_img_name)
print('FINISHED')

3. kitti数据集txt转xml

# -*- coding:utf8 -*-
# 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML
from xml.dom.minidom import Document
import cv2
import os


def generate_xml(name, split_lines, img_size, class_ind,
                 output_dir='/home/zhy/Documents/Perception/标注/training/Annotations/'):
    """Write an annotation XML file built from the lines of one label file.

    The document contains the fixed ``folder``/``source`` header, the image
    ``filename`` (``name + '.png'``), the image ``size`` and one ``object``
    element per label line whose first field is listed in ``class_ind``.

    Args:
        name: Base name (no extension) shared by the label, image and XML file.
        split_lines: Iterable of raw label lines. Fields are whitespace
            separated; field 0 is the class name and fields 4-7 are written
            as xmin, ymin, xmax, ymax after truncation to int.
        img_size: Indexable ``(height, width, depth)``, e.g. an image ``shape``.
        class_ind: Container of class names to keep; other lines are dropped.
        output_dir: Directory that receives ``<name>.xml``. Defaults to the
            location the original script used.
    """
    doc = Document()  # root DOM document
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _append_text_element(doc, annotation, 'folder', 'KITTI')
    _append_text_element(doc, annotation, 'filename', name + '.png')

    source = doc.createElement('source')
    annotation.appendChild(source)
    _append_text_element(doc, source, 'database', 'The KITTI Database')
    _append_text_element(doc, source, 'annotation', 'KITTI')

    size = doc.createElement('size')
    annotation.appendChild(size)
    # img_size is (height, width, depth), hence width comes from index 1.
    for tag, dimension in (('width', 1), ('height', 0), ('depth', 2)):
        _append_text_element(doc, size, tag, str(img_size[dimension]))

    for split_line in split_lines:
        fields = split_line.strip().split()
        # Blank lines have no fields at all; skip them instead of crashing.
        if not fields or fields[0] not in class_ind:
            continue
        obj = doc.createElement('object')
        annotation.appendChild(obj)
        _append_text_element(doc, obj, 'name', fields[0])
        bndbox = doc.createElement('bndbox')
        obj.appendChild(bndbox)
        for tag, column in (('xmin', 4), ('ymin', 5), ('xmax', 6), ('ymax', 7)):
            _append_text_element(doc, bndbox, tag, str(int(float(fields[column]))))

    # Serialise the DOM; the context manager closes the file even on error.
    xml_path = os.path.join(output_dir, name + '.xml')
    with open(xml_path, 'w', encoding='utf-8') as xml_file:
        xml_file.write(doc.toprettyxml(indent=''))


def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to ``parent`` and return the new element."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(text))
    parent.appendChild(element)
    return element


if __name__ == '__main__':
    # Only label lines whose class name is listed here end up in the XML.
    class_ind = ('Cyclist', 'Car', 'train', 'tipperTruck', 'person')
    # cur_dir=os.getcwd()
    cur_dir = '/home/zhy/Documents/Perception/标注/training'
    labels_dir = os.path.join(cur_dir, 'label_2')
    images_dir = os.path.join(cur_dir, 'image_2')  # adjust to your own layout
    for parent, dirnames, filenames in os.walk(labels_dir):
        for file_name in filenames:
            # The base name is shared by the label, the image and the XML.
            name, extension = os.path.splitext(file_name)
            if extension != '.txt':
                continue
            full_path = os.path.join(parent, file_name)
            with open(full_path) as f:  # closed as soon as the lines are read
                split_lines = f.readlines()
            img_path = os.path.join(images_dir, name + '.png')
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread returns None for a missing/unreadable image.
                print('skipping %s: cannot read image %s' % (file_name, img_path))
                continue
            generate_xml(name, split_lines, image.shape, class_ind)
print('all txts has converted into xmls')

4. 将原来的类合并

# -*- coding:utf8 -*-
import os


# modify_annotations_txt.py
#将原来的8类物体转换为我们现在需要的3类:Car,Pedestrian,Cyclist。
#我们把原来的Car、Van、Truck,Tram合并为Car类,把原来的Pedestrian,Person(sit-ting)合并为现在的Pedestrian,原来的Cyclist这一类保持不变。
import glob
import string
# Every label file matched here is inspected and rewritten in place by the code below.
txt_list = glob.glob('/home/zhy/Documents/Perception/标注/training/label_txt/*.txt')
def show_category(txt_list):
    """Print the set of category labels found in the given label files.

    The category of a row is the text before its first space (after the row
    has been stripped). A file that cannot be opened or read is reported on
    stdout and skipped.

    Args:
        txt_list: Iterable of label file paths.
    """
    seen_labels = []
    for label_path in txt_list:
        try:
            with open(label_path) as label_file:
                for row in label_file:
                    # Keep only the first space-separated field: the category.
                    seen_labels.append(row.strip().partition(' ')[0])
        except IOError as err:
            print('File error:' + str(err))
    print(set(seen_labels))  # de-duplicated view of everything collected
def merge(line):
    """Join the fields of one label row back into a text line.

    Args:
        line: Iterable of string fields.

    Returns:
        The fields separated by single spaces (no trailing space) and
        terminated by a newline; an empty input yields just the newline.
    """
    return ' '.join(line) + '\n'
print('before modify categories are:\n')
show_category(txt_list)
for item in txt_list:
    try:
        rewritten = []
        with open(item, 'r') as r_tdf:
            for raw_line in r_tdf:
                fields = raw_line.strip().split(' ')
                # Active rule: category '3' is folded into category '0'.
                if fields[0] == '3':
                    fields[0] = '0'
                # Rules that were disabled in the original script, kept here
                # as a reference for adapting it to name-based labels:
                #   'Truck', 'Van'                 -> 'tipperTruck'
                #   'Person_sitting', 'Pedestrian' -> 'person'
                #   'Tram'                         -> 'train'
                #   'Truck', 'Van', 'Tram'         -> 'Car'
                #   'Person_sitting'               -> 'Pedestrian'
                #   'DontCare', 'Misc'             -> row dropped
                rewritten.append(merge(fields))
        # 'w+' truncates the original file before the new rows are written.
        with open(item, 'w+') as w_tdf:
            w_tdf.writelines(rewritten)
    except IOError as ioerr:
        print('File error:' + str(ioerr))
print('\nafter modify categories are:\n')
show_category(txt_list)

5. 去除相似度高的图片

# -*- coding:utf8 -*-
import os
import shutil
import cv2

import numpy as np


def calc_similarity(img1_path, img2_path, threshold=0.98):
    """Tell whether two image files have near-identical histograms.

    For each file a 256-bin histogram of one channel is computed, min-max
    normalised to [0, 1], and the two histograms are compared with
    ``cv2.compareHist`` method 0 (correlation in OpenCV's enumeration).

    Args:
        img1_path: Path of the first image.
        img2_path: Path of the second image.
        threshold: Score above which the images count as similar
            (0.98 by default; tune to taste).

    Returns:
        True when the score is greater than ``threshold``. False otherwise,
        including when either file is empty or cannot be decoded.
    """
    hist1 = _normalized_histogram(img1_path)
    hist2 = _normalized_histogram(img2_path)
    if hist1 is None or hist2 is None:
        # An undecodable file cannot be judged similar to anything.
        return False
    similarity1 = cv2.compareHist(hist1, hist2, 0)
    print('similarity:', similarity1)
    return bool(similarity1 > threshold)


def _normalized_histogram(img_path):
    """Return the normalised 256-bin histogram of an image, or None if unreadable."""
    # Read the raw bytes and decode in memory (flag -1 keeps the image
    # unchanged); presumably done so paths with non-ASCII characters work on
    # every platform -- TODO confirm.
    data = np.fromfile(img_path, dtype=np.uint8)
    if data.size == 0:
        return None
    img = cv2.imdecode(data, -1)
    if img is None:
        return None
    # Channel 1 only exists for images with 3+ channels; fall back to
    # channel 0 for single-channel (or gray+alpha) images.
    channel = 1 if img.ndim == 3 and img.shape[2] >= 3 else 0
    hist = cv2.calcHist([img], [channel], None, [256], [0, 256])
    return cv2.normalize(hist, hist, 0, 1, cv2.NORM_MINMAX, -1)


# Helper for the similar-image filter: recognise image files by extension.
def is_image(file_name):
    """Return True when ``file_name`` has a known image extension.

    The check is case-insensitive and looks only at the name, not at the
    file contents.

    Args:
        file_name: File name or path.
    """
    extension = os.path.splitext(file_name)[1].lower()
    return extension in ('.bmp', '.jpg', '.jpeg', '.png')


def filter_similar(dir_path, window=4, is_similar=None):
    """Move near-duplicate files out of ``dir_path`` and count them.

    Inside every walked directory the files are processed in sorted name
    order. Each file that has not itself been filtered is compared with the
    next ``window`` files; any of those judged similar is moved into the
    ``filter_similar`` directory created next to ``dir_path``
    (``os.path.dirname(dir_path)``; for a path with a trailing separator
    that is the directory itself).

    Args:
        dir_path: Directory tree to scan.
        window: How many following files each file is compared against.
        is_similar: Callable ``(path_a, path_b) -> bool``; defaults to
            ``calc_similarity``.

    Returns:
        The number of files moved.
    """
    if is_similar is None:
        is_similar = calc_similarity
    filter_dir = os.path.join(os.path.dirname(dir_path), 'filter_similar')
    if not os.path.exists(filter_dir):
        os.mkdir(filter_dir)
    filter_abs = os.path.abspath(filter_dir)
    filter_number = 0
    for root, dirs, files in os.walk(dir_path):
        # Never descend into the output directory: re-processing files that
        # were just moved there would make shutil.move fail.
        dirs[:] = [d for d in dirs
                   if os.path.abspath(os.path.join(root, d)) != filter_abs]
        # os.walk yields names in arbitrary order; sort so that "the next
        # few files" really are neighbouring frames.
        img_files = sorted(files)
        filtered = set()  # a set, so no file is queued (and moved) twice
        for index, reference in enumerate(img_files):
            if reference in filtered:
                continue
            reference_path = os.path.join(root, reference)
            for candidate in img_files[index + 1:index + 1 + window]:
                if candidate in filtered:
                    continue
                if is_similar(reference_path, os.path.join(root, candidate)):
                    filtered.add(candidate)
        for item in sorted(filtered):
            shutil.move(os.path.join(root, item), filter_dir)
        filter_number += len(filtered)
    return filter_number


if __name__ == '__main__':
    # Script entry point: de-duplicate the images under this folder.
    filter_similar("/home/zhy/Documents/Perception/标注/image/8/")

6. 数据集的划分

# coding:utf-8

import os
import random
import argparse

parser = argparse.ArgumentParser()
# Directory holding the label files. The split lists (trainval/train/val/test)
# are written into the same directory. Point it at your own data, e.g.
# ImageSets/Main for a VOC-style layout.
parser.add_argument('--txt_path', default='/home/zhy/Documents/Perception/标注/data/labels', type=str, help='output txt label path')
opt = parser.parse_args()

trainval_percent = 1.0  # share of all samples used for train+val (no test split here)
train_percent = 0.9     # share of the train+val samples used for training
txtsavepath = opt.txt_path
# Make sure the directory exists BEFORE listing it.
if not os.path.exists(txtsavepath):
    os.makedirs(txtsavepath)

# The lists produced by a previous run live in this same directory; they are
# not samples and must not be split again.
split_file_names = ('trainval.txt', 'test.txt', 'train.txt', 'val.txt')
total_xml = sorted(entry for entry in os.listdir(txtsavepath)
                   if entry not in split_file_names)

num = len(total_xml)
list_index = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list_index, tv)
train = random.sample(trainval, tr)
# Sets give O(1) membership tests in the loop below.
trainval_set = set(trainval)
train_set = set(train)

with open(os.path.join(txtsavepath, 'trainval.txt'), 'w') as file_trainval, \
        open(os.path.join(txtsavepath, 'test.txt'), 'w') as file_test, \
        open(os.path.join(txtsavepath, 'train.txt'), 'w') as file_train, \
        open(os.path.join(txtsavepath, 'val.txt'), 'w') as file_val:
    for i in list_index:
        # One sample name (file name without its extension) per line.
        name = os.path.splitext(total_xml[i])[0] + '\n'
        if i in trainval_set:
            file_trainval.write(name)
            if i in train_set:
                file_train.write(name)
            else:
                file_val.write(name)
        else:
            file_test.write(name)

7. xml转txt

# -*- coding:utf8 -*-
import os
import xml.etree.ElementTree as ET
import io

find_path = '/home/zhy/Documents/Perception/camera_data/mine_obstacle_image/label/3/'  # directory containing the input XML files
savepath = '/home/zhy/Documents/Perception/camera_data/mine_obstacle_image/label/4/'  # directory the generated txt files are written to

# Known class names; a name's position in this list is the class id written to the txt files.
classes = ['car','lightTruck','person','tipperTruck', 'construction','tricycle','train','bicycle','bus']


class Voc_Yolo(object):
    """Convert XML bounding-box annotations into one txt label file per XML.

    Every output row is ``<class id> <x_center> <y_center> <w> <h>`` with the
    four box values divided by the image width/height read from the XML.
    """

    def __init__(self, find_path, save_path=None, class_names=None):
        """Configure the converter.

        Args:
            find_path: Directory tree searched for ``.xml`` files.
            save_path: Directory receiving the txt files; defaults to the
                module-level ``savepath``.
            class_names: Sequence of class names whose index is the class id;
                defaults to the module-level ``classes``.
        """
        self.find_path = find_path
        self.save_path = savepath if save_path is None else save_path
        self.class_names = classes if class_names is None else class_names

    def Make_txt(self, outfile):
        """Open ``outfile`` for writing, announce it on stdout and return the handle."""
        out = open(outfile, 'w')
        print("创建成功:{}".format(outfile))
        return out

    def Work(self, count):
        """Convert every XML file under ``self.find_path``.

        Args:
            count: Starting value of the processed-file counter.

        Returns:
            ``count`` plus the number of XML files converted.
        """
        for folder, dirs, files in os.walk(self.find_path):
            for file in files:
                stem, extension = os.path.splitext(file)
                if extension.lower() != '.xml':
                    continue  # only XML annotations can be parsed
                count += 1
                # Build the input path from the folder being walked so files
                # in sub-directories are found as well.
                input_file = os.path.join(folder, file)
                outfile = os.path.join(self.save_path, stem + '.txt')
                # Parse first: a malformed XML then leaves no empty txt behind.
                annotation = ET.parse(input_file).getroot()
                size = annotation.find('size')
                w_image = float(size.find('width').text)
                h_image = float(size.find('height').text)
                with self.Make_txt(outfile) as out:
                    for obj in annotation.iter('object'):
                        classname = obj.find('name').text
                        # Skip objects whose class is not configured.
                        # NOTE(review): the original comment also mentions
                        # skipping difficult == 1, but the code never read
                        # that field; confirm whether it is wanted.
                        if classname not in self.class_names:
                            continue
                        cls_id = self.class_names.index(classname)
                        xmlbox = obj.find('bndbox')
                        x_min = float(xmlbox.find('xmin').text)
                        x_max = float(xmlbox.find('xmax').text)
                        y_min = float(xmlbox.find('ymin').text)
                        y_max = float(xmlbox.find('ymax').text)
                        # The "- 1" presumably converts 1-based pixel
                        # coordinates to 0-based -- TODO confirm.
                        x_center = ((x_min + x_max) / 2 - 1) / w_image
                        y_center = ((y_min + y_max) / 2 - 1) / h_image
                        w = (x_max - x_min) / w_image
                        h = (y_max - y_min) / h_image
                        out.write(
                            str(cls_id) + " " + str(x_center) + " " + str(y_center) + " " + str(w) + " " + str(h) + '\n')
        return count


if __name__ == "__main__":
    # Script entry point: convert everything under find_path and print the file count.
    print(Voc_Yolo(find_path).Work(0))

你可能感兴趣的:(python,yolov5,xml,txt)