<计算机视觉二> labelme标定的数据转换成yolo训练格式

        上一章讲了如何使用labelme标注自己的数据集,本章将继续将标注的数据转换成网络能够训练的数据格式。首先说明下,适合自己的数据格式才是重要的,本文的数据不代表一定要这么写。有可能你在工作或者实际使用中自己摸索一套习惯用的数据格式,或者在团队有已经有了约定俗称的数据格式,本文大致说下思路和具体实现。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    : create_targets.py
@Time    : 2021/08/19 16:18:57
@Author  : XIA Yan
@Contact : 微信 lingyanlove
@Version : 0.1
@License : Apache License Version 2.0, January 2004
@Language: python3.8
@Desc    :   将labelme生成的json标签 制作成YOLO网络能够训练的格式
             为了简化代码,这里强制保存在 data文件夹中不再增加额外的路径代码
'''

import json
import os
import os.path as osp
import labelme
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
import tqdm
import glob
import json
import argparse


#1 创建标签生成
def create_label(json_path:str):
    '''
    @description:
        遍历所有的labelme json文件,生成一个class.txt标签文件
    @Args:
        json_path :(string) 训练的json文件路径
    @Return:
        None
    '''
    
    assert osp.exists(json_path),f"{json_path}不存在当前目录,请检查运行的根目录!"
    label_list = []     #创建一个list用于存放标签

    json_path = glob.glob(f"{json_path}/*.json")
    num_js    = len(json_path)
    for path in tqdm.tqdm(json_path, total= num_js, desc= "正在生成classes.txt标签文件:"):
        try:
            label_file = labelme.LabelFile(filename = str(path))
        except:
            print(path)
            exit(-1)
        for shape in label_file.shapes:
            #忽略指定标签
            if shape["label"] == "#":
                continue
            elif shape["label"] not in label_list:
                label_list.append(shape["label"])

    with open("classes.txt","w") as fp:
        for line in label_list:
            fp.write(line + "\n")

    print(f"\n保存进classes.txt,有{len(label_list)}个标签:\n{label_list}\n")


#2 创建坐标修正
def protect_label_data(path, x1,y1,x2,y2,image_w,image_h):
    '''
    由于labelme的打标签有bug,出现超过图像的边界的情况
    参数:
        x1,y1,x2,y2 是labelme里生成的标签
        image_w,image_h 是原始图像的宽度和高度
    return
        修改后的左上角右下角坐标 x1,y1,x2,y2 
    '''

    #缩放坐标在合理范围
    x1 = np.clip(x1, 0, image_w)
    y1 = np.clip(y1, 0, image_h)
    x2 = np.clip(x2, 0, image_w)
    y2 = np.clip(y2, 0, image_h)

    x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)
    assert (x1 != x2 and y1 != y2), f"{path}\n出现不合理的坐标值:{x1,y1,x2,y2} ,停止生成训练数据!"
    return x1, y1, x2, y2
    

# 3主要功能函数生成能满足yolo训练的json
def read_labelme_label(json_path, xywh = True, norm_label = True):
    '''
    读取labelme 标签的内容
    xyxy2xywh 是否转换为中心点坐标,默认不转换,因为后续还要操作
    默认输出 基于原来图像尺寸的box框 左上角和右下角坐标
    '''

    label_dict = {}  #创建标签映射  例如  "button":0
    with open("classes.txt","r") as fp:
        for idx, line in enumerate(fp.readlines()):
            line = line.strip()
            if len(line) == 0:  #出现空白行
                continue
            label_dict[line] = idx

    dict_out = {}    #存放能够训练的数据

    os.makedirs("images",exist_ok=True)  
    pathes = Path(json_path).glob("*.json")

    for path in tqdm.tqdm(list(pathes), desc= "正在生成data.json训练数据:"):
        out_coord = []      #创建存放当前图像的所有标签信息
        label_file = labelme.LabelFile(filename = str(path))
        image_name = label_file.imagePath                      #图像名称
        img        = labelme.utils.img_data_to_arr(label_file.imageData)
        img_pil    = Image.fromarray(img.astype("uint8")).convert("RGB")
        # print(f"正在保存图片{image_name}到文件夹images下.....")
        img_pil.save(f"images/{image_name}",quality = 90)

        image_path = osp.join("data/images",image_name)             #图像保存路径
        img_h, img_w,  _   = img.shape
        
        
        for shape in label_file.shapes:                        #遍历shapes找标签索引
            try:  #直接无视需要被忽略的标签
                label_id = label_dict[shape["label"]]
            except:
                continue
            
            points = shape["points"]
            points = np.array(points) 
            x1,y1,x2,y2 = points[0][0], points[0][1], points[1][0], points[1][1]
            x1,y1,x2,y2 = protect_label_data(path, x1, y1, x2, y2, img_w, img_h)

            if xywh:   #采用xywh的方式存放box
                x  = (x1 + x2)/2
                y  = (y1 + y2)/2
                w  = abs(x2 - x1)
                h  = abs(y2 - y1)
                if norm_label:
                    x = np.around((x / img_w), 4)
                    y = np.around((y / img_h), 4)
                    w = np.around((w / img_w), 4)
                    h = np.around((h / img_h), 4)

                out_coord.append([label_id, x, y, w, h])

            else:      #采用xyxy的方式存储box
                x1 = np.around(float(x1),4)
                y1 = np.around(float(y1),4)
                x2 = np.around(float(x2),4)
                y2 = np.around(float(y2),4)
                if norm_label:
                    x1 = np.around((x1 / img_w), 4)
                    y1 = np.around((y1 / img_h), 4)
                    x2 = np.around((x2 / img_w), 4)
                    y2 = np.around((y2 / img_h), 4)

                out_coord.append([label_id, x1, y1, x2, y2])

        dict_out[str(image_path)] = out_coord

    #生成json文件保存
    with (open("data.json","w")) as fp:
        json.dump(dict_out,fp,indent=4)
    print(f"所有标签生成完毕!请检查后训练!")



#4 解析classes.txt文件,返回一个字典
def load_classes(path = "data/classes.txt"):
    '''读取txt的label文件,返回一个标签list
    例如 ['lock', 'slide', 'unlock']
    '''
    with open(path, "r") as fp:
        names = fp.read().split("\n")[:-1]
    return names



#6 删除空标签
def delete_empty(path = "data.json"):
    with open(path, "r") as fp:
        json_dict = json.load(fp)     #读取json中字典 

    del_key = []
    for key, value in json_dict.items():
        if len(value) == 0:
            del_key.append(key)
    print(f"标签为空的记录:{del_key}")
    for i in del_key:
        json_dict.pop(str(i))

    #重新写入
    with open(path, "w") as nfp:
        json.dump(json_dict, nfp, indent = 4)


#7可视化结果,防止标签打错用来核对
def colors():    #每个类别拥有唯一的颜色
    label_str_int = load_classes(path = "classes.txt")

    color = [None] * len(label_str_int)  #有多少个label,对应多少个颜色
    for i in range(len(label_str_int)):
        color[i] = (np.random.randint(low = 125, high=255),  #B
                    np.random.randint(low = 100, high=255),  #G
                    np.random.randint(low = 130, high=255))  #R
    return color


def vis_label(json_path = "data.json", xyxy = False,xywh = False, norm_label = False):

    os.makedirs("vis",exist_ok=True) 
    with open(json_path, "r") as fp:
        fp_info = json.load(fp)
    
    color = colors()  #获取标签颜色

    for img_pth, targets in fp_info.items():
        print(f"正在可视化{img_pth}")
        img_BGR = cv2.imread(osp.join("images",str(Path(img_pth).name)))
        h, w, _ = img_BGR.shape
        targets = np.array(targets).reshape(-1, 5)
        
        for target in targets:
            label_int = int(target[0])
            if xywh and norm_label:
                target[1] *= w    # x
                target[2] *= h    # y
                target[3] *= w    # w
                target[4] *= h    # w
                x1 = target[1] - target[3] /2
                y1 = target[2] - target[4] /2
                x2 = target[1] + target[3] /2
                y2 = target[2] + target[4] /2
            elif xywh and (not norm_label):
                x1 = target[1] - target[3] /2
                y1 = target[2] - target[4] /2
                x2 = target[1] + target[3] /2
                y2 = target[2] + target[4] /2
            elif xyxy and norm_label:
                target[1] *= w    # x1
                target[2] *= h    # y1
                target[3] *= w    # x2
                target[4] *= h    # y2
                x1 = target[1] 
                y1 = target[2] 
                x2 = target[3] 
                y2 = target[4] 
            elif xyxy and (not norm_label):
                x1 = target[1] 
                y1 = target[2] 
                x2 = target[3] 
                y2 = target[4] 
            classes = load_classes(path = "classes.txt")
            txt_info =  classes[label_int]
            font_x, font_y = int(x1), int(y1) -2

            print(x1,y1,x2,y2)
            cv2.putText(img_BGR, str(txt_info), (font_x, font_y), cv2.FONT_HERSHEY_TRIPLEX, 1.2, color[label_int], 2)
            cv2.rectangle(img_BGR, (int(x1), int(y1)), (int(x2), int(y2)),color[label_int],2)
        cv2.imwrite(f"vis/{str(Path(img_pth).name)}",img_BGR)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--jspath', default ="jsonPath" ,type= str, help= '训练json地址')
    opt = parser.parse_args()
    print(opt)

    ##1 json路径
    path = opt.jspath

    ##2 classes.txt标签生成,检查完了在执行read_labelme_label
    create_label(path)

    ##3 生成能够训练的data.json
    read_labelme_label(path,xywh= False, norm_label= False)

    ##4 删除坐标为空的训练记录
    delete_empty(path = "data.json")

    ##6 可视化结果
    #vis_label(xyxy = True, xywh= False, norm_label= False)
    

 该脚本根据你给的json文件路径自动遍历所有分类标签申城一个classes.txt文件,例如存放的是

face

dog

文件说明
|---- jsonPath        lanbelme标记的json文件
|---- images          生成训练的图像训练时候网络模型读取的图像路径
|---- classes.txt     生成的标签,标签就是目标分类从0开始
|---- data.json       生成的json文件保存是一个字典的数据主要用于训练
使用时候需要给定 jsonPath文件 然后运行 python create_targets.py 系统会产生其他的文件

看下data.json文件下存放的是什么

{
    "data/images/sur02174.jpg": [
        [
            0,
            367.5455,
            33.0455,
            464.4545,
            127.6364
        ],
        [
            0,
            315.5909,
            115.1364,
            342.8636,
            164.0
        ],
        [
            0,
            387.0,
            165.0,
            422.4091,
            218.5455
        ],
        [
            0,
            1110.0,
            196.0,
            1144.0,
            230.0
        ],
        [
            1,
            1304.0,
            205.0,
            1345.1364,
            306.0455
        ],
        [
            0,
            762.9545,
            105.8182,
            811.0455,
            151.5
        ],
        [
            1,
            520.0,
            225.0,
            563.3182,
            277.6364
        ]
    ],
    "data/images/2964.jpg": [
        [
            1,
            478.4839,
            433.1613,
            554.2903,
            689.6129
        ],
        [
            1,
            1004.2903,
            405.7419,
            1123.5085,
            540.3729
        ]
    ],
....
...

..

.


}

 存放的是一个字典 key值为图像的路径 value是一个list第一位是目标的分类,例如你标记的数据有80个分了那么这个值就是0-79中的1个,后面四个值是box框的坐标,分别对应左上角xy和右下角xy的像素坐标。在生成的时候可以使用下面的xywh和norm_label选择是否生成box框的数据是否是xyxy还是xywh并且使不使用归一化操作。

    ##3 生成能够训练的data.json
    read_labelme_label(path,xywh= False, norm_label= False)

注:box框的描述有很多种,有以左上角右下角坐标描述这样的数据称为xyxy数据格式,有使用box框的中心点坐标和宽度高度描述称为cxcywh,另外还有左上角和宽度高度描述的记为xywh。因此不同的描述方式对应数据解析不同,初学者一定要注意这点,可能你在网路训练的时候出现莫名其妙的box框不对大概率和这个有关。

 至此数据转换已经完成,这个脚本中很多路径写死,例如生成的数据标签就叫data.json,虽然这往往不是很工程,但往往很多写死的路径为看的清晰明了。我在看yolact代码和yolov5的时候被其中的各种全局变量和第三方功能库折磨的不能自理。

下一遍将要介绍如何通过pytorch的数据模块成功读取这些数据。

你可能感兴趣的:(计算机视觉,java,服务器,servlet)