从Coco2017中提取实例分割图片,裁剪并保存为透明背景

一、Coco2017数据集结构

总的结构如下:

├─annotations
├─train2017
└─val2017

其中,train2017、val2017存放的是对应的图片;annotations中存放的是标注文件:

2017/09/02  03:04        91,865,115 captions_train2017.json
2017/09/02  03:04         3,872,473 captions_val2017.json
2017/09/02  03:02       469,785,474 instances_train2017.json
2017/09/02  03:02        19,987,840 instances_val2017.json
2017/09/02  03:04       238,884,731 person_keypoints_train2017.json
2017/09/02  03:04        10,020,657 person_keypoints_val2017.json

二、裁剪出来的透明背景图片示意图:从Coco2017中提取实例分割图片,裁剪并保存为透明背景_第1张图片

三、完整代码如下:

import os
import cv2
import numpy as np
from pycocotools.coco import COCO

def get_coco_roi(coco_path, label, save_path):
    coco = COCO(coco_path['json_path'])

    # 采用getCatIds函数获取"person"类别对应的ID
    try:
        ids = coco.getCatIds(label)[0]
        print("%s 对应的类别ID: %d" % (label, ids))
    except:
        print("[ERROR] 请输入正确的类别名称")
        return

    # 获取某一类的所有图片集合, 比如获取包含dog的所有图片
    # imgIds = coco.catToImgs[ids]
    # print("包含 {} 的图片共有 {} 张".format(label, len(imgIds)))

    # 获取某一类的所有图片集合, 比如获取包含dog的所有图片
    imgIds = coco.getImgIds(catIds=ids)
    print("包含 {} 的图片共有 {} 张".format(label, len(imgIds)))

    for img in imgIds:
        try:
            img_info = coco.loadImgs(img)[0]
        except:
            continue

        annIds = coco.getAnnIds(imgIds=img_info['id'])
        imgpath = os.path.join(coco_path['image_path'], img_info['file_name'])
        jpg_img = cv2.imread(imgpath, 1)
        if jpg_img is None:
            continue
        print("[INFO] 打开图片: ", imgpath)

        for ann in annIds:
            outline = coco.loadAnns(ann)[0]

            # 为了得到的图像比较清晰, 将较小的图片排除
            if outline['category_id'] != ids or outline['area'] < 500.0 or outline['bbox'][2] < 20:
                continue

            try:
                contour_num = len(outline['segmentation'][0])  # Coco/json文件中有错误标注
            except:
                print("[ERROR] 遇到错误: 错误轮廓点ID是: ", outline['id'])
                continue

            contour_point = []
            print("[INFO] 在json文件 {} 类别中, 对应的id是: {}, 标注内容属于 {}".format(label, outline['id'], img_info['file_name']))
            for i in range(contour_num):
                if i % 2 == 0:
                    # 将所有的轮廓点添加到集合中
                    contour_point.append([[int(outline['segmentation'][0][i]), int(outline['segmentation'][0][i + 1])]])

            # 将list转为numpy, pointPolygonTest接口接受的数据类型
            contour_point = np.array(contour_point)
            # res1 = cv2.drawContours(jpg_img, contour_point, -1, (0, 0, 255), 2)  # -1表示绘制所有轮廓
            jpg_img_1 = jpg_img.copy()
            alpha = jpg_img.copy()
            for y in range(jpg_img.shape[1]):
                for x in range(jpg_img.shape[0]):
                    result = cv2.pointPolygonTest(contour_point, (y, x), True)  # 判断任一像素点是否位于轮廓点包围的轮廓中
                    if result <= 0:  # 返回值 <0 代表不位于轮廓中, 置为背景黑色
                        jpg_img_1[x][y] = [0, 0, 0]
                        alpha[x][y] = [0, 0, 0]  # 这里将转换为四通道中的透明背景, 所以为黑色, 代表不显示
                    else:
                        alpha[x][y] = [255, 255, 255]

            # 将图片根据bbox从原图中裁剪出
            jpg_img_1 = jpg_img_1[int(outline['bbox'][1]):int(outline['bbox'][1]+outline['bbox'][3]),
                                  int(outline['bbox'][0]):int(outline['bbox'][0]+outline['bbox'][2])]
            alpha = alpha[int(outline['bbox'][1]):int(outline['bbox'][1]+outline['bbox'][3]),
                              int(outline['bbox'][0]):int(outline['bbox'][0]+outline['bbox'][2])]

            alpha = cv2.cvtColor(alpha, cv2.COLOR_RGB2GRAY)
            b, g, r = cv2.split(jpg_img_1)
            rgba = [b, g, r, alpha]  # 将三通道图片转化为四通道(背景透明)的图片
            dst = cv2.merge(rgba, 4)

            name, shuifx = os.path.splitext(img_info['file_name'])
            imPath = os.path.join(save_path, name + "_%05d" % (int(annIds.index(ann))) + ".png")
            print("[INFO] 保存到 %s, 当前进度: %d /%d" % (imPath, imgIds.index(img), len(imgIds)))
            # cv2.imwrite(imPath, dst)
            cv2.imencode('.png', dst)[1].tofile(imPath)  # 保存中文路径的方法
        print("")

if __name__ == "__main__":
    # 定义Coco数据集根目录
    coco_root = r"E:/003 Datasets/002 CoCo2017/"

    coco_data = ['train2017', 'val2017']

    coco_path = {
        'image_path': os.path.join(coco_root, coco_data[0]),
        'json_path': coco_root + r"/annotations/instances_%s.json" % coco_data[0]}

    # 定义需要提取的类别
    labels = ["person", "dog", "cat", "bicycle", "car", "motorcycle", "bus", "truck"]
    for label in labels:
        # if label == "person" or label == "dog" or label == "cat":
        #     continue
        save_path = os.path.join(os.path.join(coco_root, "贴图数据集_train"), str(labels.index(label)) + "_" + label)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        get_coco_roi(coco_path, label, save_path)

    print("[INFO] 抠图结束")

经过运行结果两天都没有跑完train,同事提示可以直接导出Mask再用numpy计算,故更改为以下版本,效率提升10倍,记录一下

import os
import cv2
import numpy as np
from pycocotools.coco import COCO

def get_coco_roi(ins_coco, key_coco, image_path, label, save_path):
    try:
        # 采用getCatIds函数获取"person"类别对应的ID
        ins_ids = ins_coco.getCatIds(label)[0]
        print("%s 对应的类别ID: %d" % (label, ins_ids))
    except:
        print("[ERROR] 请输入正确的类别名称")
        return

    # 获取某一类的所有图片集合, 比如获取包含dog的所有图片
    imgIds = ins_coco.getImgIds(catIds=ins_ids)
    print("包含 {} 的图片共有 {} 张".format(label, len(imgIds)))

    for img in imgIds:
        try:
            img_info = ins_coco.loadImgs(img)[0]
        except:
            continue

        annIds = ins_coco.getAnnIds(imgIds=img_info['id'])
        imgpath = os.path.join(image_path, img_info['file_name'])
        jpg_img = cv2.imread(imgpath, 1)
        if jpg_img is None:
            continue

        for ann in annIds:
            outline = ins_coco.loadAnns(ann)[0]

            # 只提取类别对应的标注信息
            if outline['category_id'] != ins_ids:
                continue

            # 对人同时使用关键点判断, 如果关键点中含有0的数量比较多, 代表这个人是不完整或姿态不好的
            if outline['category_id'] == 1:
                key_outline = key_coco.loadAnns(ann)[0]
                if key_outline['keypoints'].count(0) >= 10:
                    continue

            # 将轮廓信息转为Mask信息并转为numpy格式
            mask = ins_coco.annToMask(outline)
            mask = np.array(mask)

            # 复制并扩充维度与原图片相等, 用于后续计算
            mask_three = np.expand_dims(mask, 2).repeat(3, axis=2)

            jpg_img = np.array(jpg_img)

            # 如果mask矩阵中元素大于0, 则置为原图的像素信息, 否则置为黑色
            result = np.where(mask_three > 0, jpg_img, 0)

            # 如果mask矩阵中元素大于0, 则置为白色, 否则为黑色, 用于生成第4通道图像信息
            alpha = np.where(mask > 0, 255, 0)
            alpha = alpha.astype(np.uint8)  # 转换格式, 防止拼接时由于数据格式不匹配报错

            b, g, r = cv2.split(result)  # 分离三通道, 准备衔接上第4通道
            rgba = [b, g, r, alpha]  # 将三通道图片转化为四通道(背景透明)的图片
            dst = cv2.merge(rgba, 4)  # 拼接4个通道
            dst = dst[int(outline['bbox'][1]):int(outline['bbox'][1]+outline['bbox'][3]),
                              int(outline['bbox'][0]):int(outline['bbox'][0]+outline['bbox'][2])]

            name, shuifx = os.path.splitext(img_info['file_name'])
            imPath = os.path.join(save_path, name + "_%05d" % (int(annIds.index(ann))) + ".png")
            print("[INFO] 当前进度: %d /%d" % (imgIds.index(img), len(imgIds)))
            # cv2.imwrite(imPath, dst)
            cv2.imencode('.png', dst)[1].tofile(imPath)  # 保存中文路径的方法

if __name__ == "__main__":
    # 定义Coco数据集根目录
    coco_root = r"E:/003 Datasets/002 CoCo2017/"

    coco_data = ['train2017', 'val2017']

    # 定义需要提取的类别
    labels = ["person", "dog", "cat", "bicycle", "car", "motorcycle", "bus", "truck"]
    for data in coco_data:
        coco_path = {
            'image_path': os.path.join(coco_root, data),
            'instances_json_path': coco_root + r"/annotations/instances_%s.json" % data,
            'keypoints_json_path': coco_root + r"/annotations/person_keypoints_%s.json" % data
        }
        ins_coco = COCO(coco_path['instances_json_path'])
        key_coco = COCO(coco_path['keypoints_json_path'])

        for label in labels:
            save_path = os.path.join(os.path.join(coco_root, "贴图数据集_%s" % data), str(labels.index(label)) + "_" + label)
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            get_coco_roi(ins_coco, key_coco, coco_path['image_path'], label, save_path)

    print("[INFO] 抠图结束")

代码使用说明:更改数据集根目录、更改想抠出的类,更改保存地址运行即可

顺带感谢该博主给予的灵感:https://blog.csdn.net/oYeZhou/article/details/111994155

你可能感兴趣的:(python,计算机视觉)