# 对COCO格式的json做离线数据增强：随机选取图像中的dustproof类bbox，用从labelme标注中裁剪出的图块覆盖贴图，并移除被覆盖bbox对应的annotation。


import cv2
import os
import numpy as np
import json
import random
from pycocotools.coco import COCO, maskUtils

# ROOT: absolute directory containing this script.
# BASE: the parent of ROOT; source json, train2017/ and outputs live there.
ROOT = os.path.split(os.path.abspath(__file__))[0]
BASE = os.path.split(ROOT)[0]

def mkdir_os(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of testing ``os.path.exists`` first.
    This removes the check-then-create race: if the directory appears between
    the check and the create, ``os.makedirs`` would otherwise raise
    ``FileExistsError``.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def load_json(filenamejson):
    """Read the UTF-8 encoded JSON file ``filenamejson`` and return the decoded object."""
    with open(filenamejson, encoding='utf-8') as handle:
        return json.load(handle)

def searchDirFile(rootDir, list_, end_suffix):
    """Recursively collect files under ``rootDir`` whose names end with ``end_suffix``.

    Matching paths (``rootDir`` joined with the entry name) are appended to
    ``list_`` in place, in ``os.listdir`` order, descending into each
    sub-directory as it is met. Entries that are neither a regular file nor
    a directory are reported with a print and skipped.
    """
    for entry_name in os.listdir(rootDir):
        entry_path = os.path.join(rootDir, entry_name)
        if os.path.isfile(entry_path):
            # listdir names contain no separator, so this equals the basename.
            if entry_name.endswith(end_suffix):
                list_.append(entry_path)
        elif os.path.isdir(entry_path):
            searchDirFile(entry_path, list_, end_suffix)
        else:
            print('not file and dir ' + entry_name)


# all_categories = [
#     {
#         "name": "A",
#         "id": 1
#     },
#     {
#         "name": "B",
#         "id": 2
#     },
#     {
#         "name": "dustproof",
#         "id": 3
#     },
#     {
#         "name": "D",
#         "id": 4
#     }
#  ]

# Directory of labelme-annotated images whose "dustproof" regions are used as
# paste-source patches.
path2 = os.path.join(ROOT, "排风堵丢失")
json_list = []
searchDirFile(path2, json_list, ".json")


def _imread_unicode(path, flags=-1):
    """Read an image with OpenCV, tolerating non-ASCII paths.

    On Windows, ``cv2.imread`` cannot open paths containing non-ASCII
    characters (such as ``path2`` above) and returns None. Reading the raw
    bytes with numpy and decoding them avoids that.

    Returns None when the file is missing, empty or cannot be decoded.
    """
    try:
        buf = np.fromfile(path, dtype=np.uint8)
    except OSError:
        return None
    if buf.size == 0:
        # cv2.imdecode asserts on an empty buffer instead of returning None.
        return None
    return cv2.imdecode(buf, flags)


# Read the labelme regions to crop: every "dustproof" shape becomes one BGR
# patch in crop_list.
crop_list = []
for json_path in json_list:
    # Swap only the extension; str.replace would also rewrite any ".json"
    # occurring in a directory name.
    jpg_path = os.path.splitext(json_path)[0] + ".jpg"
    img = _imread_unicode(jpg_path, -1)
    if img is None:
        continue
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    with open(json_path, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
    img_h, img_w = img.shape[:2]
    for shape in json_data['shapes']:
        if shape['label'] != 'dustproof':
            continue
        points = shape['points']
        if len(points) < 2:
            continue
        # Use the bounding box of all points. This is identical for a
        # 2-point rectangle, and it works no matter which corner the
        # rectangle was drawn from (and for polygons).
        # Clip to the image so a slightly negative coordinate cannot wrap
        # around via numpy negative indexing.
        xs = [pt[0] for pt in points]
        ys = [pt[1] for pt in points]
        x1 = max(0, int(min(xs)))
        y1 = max(0, int(min(ys)))
        x2 = min(img_w, int(max(xs)) + 1)
        y2 = min(img_h, int(max(ys)) + 1)
        crop_img = img[y1:y2, x1:x2, :]
        if crop_img.size == 0:
            # An empty patch would make cv2.resize fail later.
            continue
        crop_list.append(crop_img)

# Source COCO file whose "dustproof" boxes get covered. root_data is decoded
# explicitly as UTF-8 and supplies the metadata copied to the output.
# pycocotools builds its lookup index from the same file.
SRC_COCO_JSON = os.path.join(BASE, "merge_20201125-20201210-20210105-20210105aug-no-nizi.json")
root_data = load_json(SRC_COCO_JSON)
coco = COCO(SRC_COCO_JSON)
categories = coco.dataset['categories']

# Resolve the target class by name instead of hard-coding id 3. A renumbered
# category list then cannot silently select the wrong class.
dustproof_cat_ids = {cat['id'] for cat in categories if cat['name'] == 'dustproof'}
if not dustproof_cat_ids:
    raise SystemExit('no "dustproof" category found in ' + SRC_COCO_JSON)
if not crop_list:
    raise SystemExit('no "dustproof" crops were collected; nothing to paste')

m_file = "aug_result_20210618"
mkdir_os(os.path.join(BASE, m_file))

save_dict = {
    'categories': root_data["categories"],
    'images': None,
    'annotations': None,
    'info': root_data.get("info", {}),
    # Standard COCO spells this key "licenses"; accept either spelling.
    'license': root_data.get("license", root_data.get("licenses", [])),
}


def _bbox_to_pixel_region(bbox, img_h, img_w):
    """Return integer pixel bounds ``(x1, y1, x2, y2)`` covering COCO ``bbox``.

    ``bbox`` is ``[x, y, w, h]``. The bounds are expanded outward
    (floor/ceil) so the whole labelled object is covered. Its annotation is
    removed, so no strip of the original object may stay visible. The bounds
    are then clipped to the image.

    Returns None when nothing of the box lies inside the image.
    """
    bx, by, bw, bh = bbox[:4]
    x1 = max(0, int(np.floor(bx)))
    y1 = max(0, int(np.floor(by)))
    x2 = min(img_w, int(np.ceil(bx + bw)))
    y2 = min(img_h, int(np.ceil(by + bh)))
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


new_images_list = []
new_annotations_list = []
new_image_id = 0
new_ann_id = 0
imgIds = coco.getImgIds(catIds=[])
for i, img_id in enumerate(imgIds):
    if i % 100 == 0:
        print(i, "/", len(imgIds))
    img_info = coco.loadImgs(img_id)[0]

    anns = coco.loadAnns(coco.getAnnIds(imgIds=img_info['id']))
    select_ann = [a for a in anns if a['category_id'] in dustproof_cat_ids]
    other_ann = [a for a in anns if a['category_id'] not in dustproof_cat_ids]
    if not select_ann:
        continue

    src_name = img_info['file_name']
    cvImage = cv2.imread(os.path.join(BASE, "train2017", src_name), -1)
    if cvImage is None:
        print('if cvImage is None:', src_name)
        # Non-zero status: the bare exit() used before reported success.
        raise SystemExit(1)
    if cvImage.ndim == 2:
        cvImage = cv2.cvtColor(cvImage, cv2.COLOR_GRAY2BGR)
    img_h, img_w = cvImage.shape[:2]

    # Randomly choose how many, and which, dustproof boxes to cover.
    aug_num = random.randint(1, len(select_ann))
    be_auged_ind = random.sample(range(len(select_ann)), aug_num)
    # Use distinct patches when there are enough of them; otherwise allow
    # repeats. random.sample raises ValueError when aug_num > len(crop_list).
    if aug_num <= len(crop_list):
        chosen_crops = random.sample(crop_list, aug_num)
    else:
        chosen_crops = random.choices(crop_list, k=aug_num)

    covered = set()
    for ann_idx, crop in zip(be_auged_ind, chosen_crops):
        region = _bbox_to_pixel_region(select_ann[ann_idx]['bbox'], img_h, img_w)
        if region is None:
            # Degenerate or off-image box: nothing is pasted, so its label stays.
            continue
        x1, y1, x2, y2 = region
        # dsize is (width, height); resizing to the clipped region keeps the
        # assignment shape-compatible for boxes touching the border.
        resized = cv2.resize(crop, (x2 - x1, y2 - y1), interpolation=cv2.INTER_LINEAR)
        cvImage[y1:y2, x1:x2] = resized
        covered.add(ann_idx)

    if not covered:
        # Nothing changed; emitting a renamed duplicate image adds no value.
        continue

    img_info['id'] = new_image_id
    img_info['file_name'] = "061801_dustproof-aug_" + src_name
    new_image_id += 1

    out_path = os.path.join(BASE, m_file, img_info['file_name'])
    # cv2.imwrite signals failure (e.g. file_name has a sub-directory that
    # does not exist under m_file) by returning False rather than raising.
    if not cv2.imwrite(out_path, cvImage, [cv2.IMWRITE_JPEG_QUALITY, 100]):
        raise SystemExit('failed to write ' + out_path)
    new_images_list.append(img_info)

    # Keep every non-dustproof label plus the dustproof boxes left uncovered.
    # The annotations are renumbered and re-pointed at the new image id.
    kept_anns = other_ann + [a for k, a in enumerate(select_ann) if k not in covered]
    for ann in kept_anns:
        ann['id'] = new_ann_id
        new_ann_id += 1
        ann['image_id'] = img_info['id']
        new_annotations_list.append(ann)

save_dict["images"] = new_images_list
save_dict["annotations"] = new_annotations_list
print("最终生成的json有 {0} 个图片".format(len(save_dict['images'])))
print("最终生成的json有 {0} 个annotation".format(len(save_dict['annotations'])))
with open(os.path.join(ROOT, "061801_dustproof-aug.json"), 'w', encoding="utf-8") as json_file:
    json.dump(save_dict, json_file, ensure_ascii=False, indent=1)


# catIds = coco.getCatIds(catNms=["dustproof"], supNms=[], catIds=[])
# imgIds = coco.getImgIds(catIds=catIds)
# #imgIds = coco.getImgIds(catIds=[])
# for i in range(len(imgIds)):
#     if i % 100 == 0:
#         print(i, "/", len(imgIds))
#     img = coco.loadImgs(imgIds[i])[0]

#     cvImage = cv2.imread(os.path.join(BASE, "train2017", img['file_name']), -1)

#     if cvImage is None:
#         print('if cvImage is None:', img['file_name'])
#         exit()

#     if cvImage.ndim == 2:
#         cvImage = cv2.cvtColor(cvImage, cv2.COLOR_GRAY2BGR)


#     annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
#     anns = coco.loadAnns(annIds)

#     catIds_len_anns = len(anns)

#     debug = 1



 

# 你可能感兴趣的:(深度学习)