# COCO2017 数据集：COCO 格式转 YOLO 格式

# 将 COCO2017 数据集的训练标注（.json）转换成 YOLO 格式（.txt）

# 数据集：COCO2017 数据集

import os 
import json
import random
import argparse

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from matplotlib import patches

"""
COCO 格式的数据集转化为 YOLO 格式的数据集
--json_path 输入的json文件路径
--save_path 标签 txt 文件的保存目录，默认为 ./labels/valid。
--image_path 原始的图像存储路径
--image_path_to_txt 图片路径清单 txt 的文件名（保存在 ./images/ 目录下），原始图片的绝对路径逐行写入其中
--out_put 输出的类别名称文件（与 --save_path 拼接），默认为 ../coco.names
"""


#val 数据
def arg_parser(argv=None):
    """Parse the command-line options of the COCO -> YOLO converter.

    The defaults target the COCO2017 validation split (``instances_val2017.json``
    / ``val2017`` / ``./labels/valid``).

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with ``json_path``, ``save_path``, ``image_path``,
        ``image_path_to_txt`` and ``out_put``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--json_path', default='./instances_val2017.json', type=str, help="input: coco format(json)")
    parser.add_argument('--save_path', default='./labels/valid', type=str, help="specify where to save the output dir of labels")
    parser.add_argument('--image_path', default='./val2017', type=str, help="load images path")
    parser.add_argument('--image_path_to_txt', default='valid.txt', type=str, help="store images path to txt")
    parser.add_argument('--out_put', default='../coco.names', type=str, help="output file of category names (joined onto --save_path)")
    return parser.parse_args(argv)

# # train数据
# def arg_parser():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--json_path', default='./instances_train2017.json',type=str, help="input: coco format(json)")
#     parser.add_argument('--save_path', default='./labels/train', type=str, help="specify where to save the output dir of labels")
#     parser.add_argument('--image_path', default='./train2017', type=str, help="load images path")
#     parser.add_argument('--image_path_to_txt', default='train.txt', type=str, help="store images path to txt")
#     parser.add_argument('--out_put', default='../coco.names', type=str, help="output images names")
#     return parser.parse_args()


#加载图片的类别名字
def load_classes(path="./labels/coco.names"):
    """Read class names from a names file, one name per line.

    Args:
        path: Path of the names file.

    Returns:
        list[str]: The names in file order. A final trailing newline is
        optional; the last name is kept either way. Blank lines inside the
        file are returned as empty strings.
    """
    with open(path, "r") as fp:
        # splitlines() does not produce a phantom empty entry for a trailing
        # newline, so no slicing is needed and no real name can be dropped.
        return fp.read().splitlines()

#测试转换的图片是否正确
def test_image_convrt(images_path="./val2017", label_path="./labels/valid", class_name='./labels/classes.names'):
    """Visually spot-check the conversion by drawing one image's YOLO boxes.

    A random image from ``images_path`` that has a matching ``.txt`` file in
    ``label_path`` is chosen. Its normalized ``class cx cy w h`` rows are
    converted back to pixel rectangles, drawn on the image with the class
    name, and shown with matplotlib.

    Args:
        images_path: Directory holding the original images.
        label_path: Directory holding the YOLO ``.txt`` label files.
        class_name: Names file (one class name per line) used to label boxes.

    Raises:
        FileNotFoundError: if no image in ``images_path`` has a label file.
        OSError: if the chosen image cannot be decoded by OpenCV.
    """
    names = load_classes(path=class_name)

    images_dir = os.path.abspath(images_path)

    # Only images with a label file are candidates; otherwise a stray file in
    # the image folder would make np.loadtxt fail. Stems are taken with
    # os.path.splitext, the same rule the converter uses to name label files.
    label_stems = {
        os.path.splitext(f)[0]
        for f in os.listdir(label_path)
        if f.endswith(".txt")
    }
    candidates = sorted(
        n for n in os.listdir(images_dir) if os.path.splitext(n)[0] in label_stems
    )
    if not candidates:
        raise FileNotFoundError(
            f"no image in {images_dir} has a label file in {label_path}"
        )

    image = random.choice(candidates)
    image_file = os.path.join(images_dir, image)

    img = cv.imread(image_file)
    if img is None:  # cv.imread signals failure by returning None
        raise OSError(f"cannot read image: {image_file}")
    # OpenCV decodes to BGR while matplotlib expects RGB.
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    H, W = img.shape[:2]

    label_file = os.path.join(label_path, os.path.splitext(image)[0] + ".txt")
    boxes = np.loadtxt(label_file, dtype=float).reshape(-1, 5)

    # Normalized (cx, cy, w, h) -> pixel (xmin, ymin, w, h), the form
    # patches.Rectangle takes. xmin/ymin are computed first because they
    # need the still-normalized w/h.
    boxes[:, 1] = (boxes[:, 1] - boxes[:, 3] / 2) * W
    boxes[:, 2] = (boxes[:, 2] - boxes[:, 4] / 2) * H
    boxes[:, 3] *= W
    boxes[:, 4] *= H

    fig = plt.figure()
    ax = fig.subplots(1)
    # Draw the image once, up front, so it is shown even for an empty label file.
    ax.imshow(img)

    for cls_id, x0, y0, bw, bh in boxes:
        ax.add_patch(
            patches.Rectangle((x0, y0), bw, bh, linewidth=1,
                              edgecolor='r', facecolor="none")
        )
        # Class name in a green box at the rectangle's top-left corner.
        ax.text(
            x0,
            y0,
            s=names[int(cls_id)],
            color="white",
            verticalalignment="top",
            bbox={"color": 'g', "pad": 0},
        )
    plt.show()

def writejpg2txt(images_path, txt_name):
    """Write the absolute path of every file in ``images_path`` to ``txt_name``.

    Paths are written one per line, sorted by file name. Subdirectories are
    skipped because the list is meant to contain image files only.

    Args:
        images_path: Directory holding the images (relative or absolute).
        txt_name: Path of the list file to create/overwrite.

    Returns:
        int: Number of paths written.
    """
    images_dir = os.path.abspath(images_path)
    image_names = sorted(
        n for n in os.listdir(images_dir)
        if os.path.isfile(os.path.join(images_dir, n))
    )

    # `with` guarantees the list file is closed even if a write fails.
    with open(txt_name, "w") as list_file:
        list_file.writelines(os.path.join(images_dir, n) + '\n' for n in image_names)

    count = len(image_names)
    print('生成txt成功!')
    print('{} 张图片地址已写入'.format(count))
    return count
 
def convert(size, box):
    """Convert a COCO-style bbox to normalized YOLO ``(x_center, y_center, w, h)``.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``[x_min, y_min, width, height]`` in pixels.

    Returns:
        tuple: Four floats. The box centre and size are each divided by the
        matching image dimension and rounded to 6 decimal places.
    """
    img_w, img_h = size[0], size[1]
    inv_w = 1. / img_w
    inv_h = 1. / img_h

    x_min, y_min, box_w, box_h = box[0], box[1], box[2], box[3]
    center_x = x_min + box_w / 2.0
    center_y = y_min + box_h / 2.0

    # Normalize by multiplying with the reciprocal and keep 6 decimals.
    return (
        round(center_x * inv_w, 6),
        round(center_y * inv_h, 6),
        round(box_w * inv_w, 6),
        round(box_h * inv_h, 6),
    )

if __name__ == '__main__':

    args = arg_parser()
    json_file = args.json_path  # COCO object-instance annotation file
    ana_txt_save_path = args.save_path  # directory for the per-image label files
    image_path = args.image_path
    image_path_to_txt = args.image_path_to_txt
    out_put = args.out_put

    # Parse the annotation JSON and close the file right away.
    with open(json_file, 'r') as json_fp:
        data = json.load(json_fp)

    os.makedirs(ana_txt_save_path, exist_ok=True)
    os.makedirs("images", exist_ok=True)

    # The image-path list is written under ./images/.
    image_path_to_txt = os.path.join("./images", image_path_to_txt)

    # COCO category ids are not contiguous, so remap them to 0..N-1 in the
    # order the categories appear, writing the matching names file as we go.
    id_map = {}
    out = os.path.join(ana_txt_save_path, out_put)
    with open(out, 'w') as f:
        for i, category in enumerate(data['categories']):
            f.write(f"{category['name']}\n")
            id_map[category['id']] = i

    # Group annotations by image id for per-image lookup.
    anns = {}
    for ann in data['annotations']:
        anns.setdefault(ann['image_id'], []).append(ann)

    print('got anns')

    # Write the list of image paths to the txt file.
    writejpg2txt(image_path, image_path_to_txt)

    for img in tqdm(data['images']):
        img_width = img["width"]
        img_height = img["height"]
        # The label file shares the image's stem: xxx.jpg -> xxx.txt
        head = os.path.splitext(img["file_name"])[0]

        lines = []
        for ann in anns.get(img["id"], []):
            box = convert((img_width, img_height), ann["bbox"])
            lines.append("%s %s %s %s %s\n" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3]))

        # Images without annotations still get an (empty) label file.
        with open(os.path.join(ana_txt_save_path, head + ".txt"), 'w') as f_txt:
            f_txt.writelines(lines)

    # Visually verify one converted image.
    test_image_convrt(images_path=image_path, label_path=ana_txt_save_path, class_name=out)

# 你可能感兴趣的：笔记、python、计算机视觉、数据分析