目标检测技巧:误识别数据喂入

在一些目标检测任务中,经常有一些容易误识别的数据,比如博主最近做的台标识别的项目,经常有一些芒果之类的被误识别为湖南卫视,而实验发现mmdetection并不支持无ground truth的输入图像,那如何喂入误识别数据呢?

本博客使用了一种有人称之为“填鸭式的方法”,即将误识别的区域使用正方形标注出来,然后随机打开训练集中的一张图像,将误识别切割出来的区域放到鼠标指定的位置(鼠标指定左上角,其余自动填充)

程序中会蹦出两个cv.show的图像,第一个是误识别的数据,第二个是随机打开的训练集的图片,点完之后按随意键打开下一张图片。一个误识别区域会随机放在五张图片中。

import cv2 as cv
import os
import numpy as np
from PIL import Image
import random

global img, pt_curr, pt_pre, counts, file_list, index, tar_img, img_crop
img_crop = np.zeros((10,10))
tar_img = np.zeros((100,100))
file_list = []
counts = 0
index = 0

def on_mouse1(event, x, y, flags, param):  # 鼠标左键按下
    global pt_curr, pt_pre, index, tar_img, img_crop, tar_name
    if event == cv.EVENT_LBUTTONDOWN:
        pt_curr = (x, y)
        pt_pre = pt_curr
        h, w, _ = np.shape(img_crop)
        s = tar_name.split('/')
        tar_name = '/Users/hank/Desktop/ad_12/train_aug/' + s[-1]
        print(np.shape(img_crop), np.shape(tar_img), tar_name)
        tar_img[x:x+h, y:y+w] = img_crop
        cv.imwrite(tar_name, tar_img)

def on_mouse(event, x, y, flags, param):  # 鼠标左键按下
    global img, pt_curr, pt_pre, index
    img_cpy = img.copy()
    if event == cv.EVENT_LBUTTONDOWN:
        pt_curr = (x, y)
        pt_pre = pt_curr
        cv.circle(img_cpy, pt_curr, 10, (0, 255, 0), 2)
        cv.imshow('image', img_cpy)
    elif event == cv.EVENT_MOUSEMOVE and (flags & cv.EVENT_FLAG_LBUTTON):  # 左键保持按下,且进行拖动
        pt_curr = (x, y)
        cv.rectangle(img_cpy, pt_pre, pt_curr, (255, 0, 0), 2)
        cv.imshow('image', img_cpy)
    elif event == cv.EVENT_LBUTTONUP:  # 鼠标左键松开
        pt_curr = (x, y)
        cv.rectangle(img_cpy, pt_pre, pt_curr, (0, 0, 255), 2)
        cv.imshow('image', img_cpy)
        min_x = min(pt_pre[0], pt_curr[0])
        min_y = min(pt_pre[1], pt_curr[1])
        width = abs(pt_pre[0] - pt_curr[0])
        height = abs(pt_pre[1] - pt_curr[1])
        global img_crop
        img_crop = img[min_y:min_y + height, min_x:min_x + width]
        img_crop = Image.fromarray(img_crop)
        global counts
        counts = counts + 1
        for i in range(5):
            global tar_name
            tar_name = file_list[index]
            index += 1
            print(tar_name)
            global tar_img
            tar_img = cv.imread(tar_name)
            cv.namedWindow('tar_img')
            cv.setMouseCallback('tar_img', on_mouse1)
            cv.imshow('tar_img',tar_img)
            cv.waitKey(0)





if __name__ == '__main__':
    root_dir = './wujian'
    for filename1 in os.listdir('/Users/hank/Desktop//ad_12/train'):
        filename = os.path.join('/Users/hank/Desktop/ad_12/train', filename1)
        file_list.append(filename)
    random.shuffle(file_list)
    for img_name in os.listdir(root_dir):
        global img
        print("image name: {}".format(img_name))
        img = cv.imread(os.path.join(root_dir, img_name))
        cv.namedWindow('image')
        cv.setMouseCallback('image', on_mouse)
        cv.imshow('image', img)
        cv.waitKey(0)

 

你可能感兴趣的:(计算机视觉,python)