python opencv标注小工具,白平衡数据集标注

白平衡数据集标注工具

做了一个简单的白平衡数据集标注工具。
主要参考:https://blog.csdn.net/guyuealian/article/details/88013421

使用(鼠标左键拖拽框选参考区域;按 y 确认当前框,n 关闭当前 ROI 预览,a / d 切换上一张 / 下一张,Enter 或空格退出):

python get_illuminant_label_v2.py -mode illuminant -read_path /data/xywang/dataset/ISP_Dataset/AWB_wxy/20210507/PNG/ -save_path /data/xywang/dataset/ISP_Dataset/AWB_wxy/20210507/PNG_GT/ -save_label /data/xywang/dataset/ISP_Dataset/AWB_wxy/20210507/label.txt -save_list all_test.txt

代码:

# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
import argparse
import matplotlib.pyplot as plt

# Shared state between the OpenCV mouse callback (on_mouse) and the ROI
# selection loop (get_image_roi). A `global` statement at module level is a
# no-op and does not create the names, so they are initialised explicitly:
# otherwise reading g_rect before any rectangle was drawn raises NameError.
img = None      # BGR image currently shown in the labelling window
point1 = None   # (x, y) where the left mouse button was pressed
point2 = None   # (x, y) where the left mouse button was released
g_rect = None   # last selected ROI as [x, y, w, h] in displayed-image coords
 
def on_mouse(event, x, y, flags, param):
    """OpenCV mouse callback used to drag a rectangular ROI on the global `img`.

    Left-button press marks the first corner (green circle), dragging with
    the button held previews the rectangle (blue), and release draws the
    final rectangle (red) and, when it has a positive area, stores it in the
    global ``g_rect`` as ``[x, y, width, height]`` (coordinates of the
    displayed image) and shows the crop in a window named 'ROI'.

    Args:
        event: OpenCV mouse event code (cv2.EVENT_*).
        x, y: cursor position reported by OpenCV.
        flags: OpenCV event flags (cv2.EVENT_FLAG_*).
        param: name of the window being drawn on, as passed to
            cv2.setMouseCallback.
    """
    global img, point1, point2, g_rect
    # While the button is held the cursor may leave the window and OpenCV can
    # report coordinates outside the image; clamp them so the crop below never
    # uses negative (wrap-around) indices.
    img_h, img_w = img.shape[:2]
    x = min(max(x, 0), img_w - 1)
    y = min(max(y, 0), img_h - 1)
    img2 = img.copy()
    if event == cv2.EVENT_LBUTTONDOWN:
        point1 = (x, y)
        cv2.circle(img2, point1, 10, (0, 255, 0), 5)
        cv2.imshow(param, img2)

    elif event == cv2.EVENT_MOUSEMOVE and (flags & cv2.EVENT_FLAG_LBUTTON):
        cv2.rectangle(img2, point1, (x, y), (255, 0, 0), thickness=2)
        cv2.imshow(param, img2)

    elif event == cv2.EVENT_LBUTTONUP:
        point2 = (x, y)
        cv2.rectangle(img2, point1, point2, (0, 0, 255), thickness=2)
        cv2.imshow(param, img2)
        min_x = min(point1[0], point2[0])
        min_y = min(point1[1], point2[1])
        width = abs(point1[0] - point2[0])
        height = abs(point1[1] - point2[1])
        # A click or a purely horizontal/vertical drag yields an empty crop:
        # cv2.imshow rejects empty images and the channel means would be NaN,
        # so only rectangles with a positive area are accepted.
        if width > 0 and height > 0:
            g_rect = [min_x, min_y, width, height]
            cut_img = img[min_y:min_y + height, min_x:min_x + width]
            cv2.imshow('ROI', cut_img)

def _close_roi_window():
    """Close the 'ROI' preview window if it exists (best effort)."""
    try:
        cv2.destroyWindow("ROI")
    except cv2.error:
        # Some backends raise when the window was never opened or was already
        # closed by the user; there is nothing to clean up in that case.
        pass


def get_image_roi(rgb_image, image_path):
    """Show `rgb_image` and let the user select an ROI with the mouse.

    Keys:
        y              accept the rectangle drawn on this image -> (g_rect, key)
        n              discard the current rectangle and keep waiting
        d / a          next / previous image -> closes the window, (None, key)
        Enter / Space  quit -> closes the window, (None, key)

    Pressing 'y' before a rectangle has been drawn on this image prints a
    hint and keeps waiting.

    Args:
        rgb_image: RGB image to display (converted to BGR for OpenCV).
        image_path: used as the window title.

    Returns:
        (rect, key): rect is [x, y, w, h] in rgb_image coordinates, or None.
    """
    global img, g_rect
    img = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    # Forget any selection from a previous image, so 'y' can never return a
    # rectangle the user did not draw on this one.
    g_rect = None
    cv2.namedWindow(image_path)
    # Registering the callback once is enough; it receives the window name.
    cv2.setMouseCallback(image_path, on_mouse, image_path)
    while True:
        # Redraw the clean image (this also erases a discarded preview).
        cv2.imshow(image_path, img)
        # Mask to the low byte: some older builds set modifier bits in the code.
        key = cv2.waitKey(0) & 0xFF
        if key == ord('y'):
            if g_rect is None:
                print("no ROI selected on this image, drag a rectangle first")
                continue
            _close_roi_window()
            return g_rect, key
        if key == ord('n'):
            _close_roi_window()
            g_rect = None
        elif key in (13, 32, ord('d'), ord('a')):
            cv2.destroyWindow(image_path)
            return None, key
 
def select_user_roi(image_path):
    """Let the user draw an ROI on an image and average its colour.

    The image is shown at a height of 800 px. The accepted rectangle is
    mapped back to full resolution before the channel means are computed.

    Returns:
        (orig_image, R_mean, G_mean, B_mean, key): the full-resolution RGB
        image and the ROI means, or four ``None`` values plus the key when no
        rectangle was accepted.
    """
    full_res = read_image(image_path)
    preview = resize_image(full_res, resize_height=800, resize_width=None)
    rect, key = get_image_roi(preview, image_path)

    if rect is None:
        return None, None, None, None, key

    rect_full = scale_rect(rect, np.shape(preview), np.shape(full_res))
    roi = get_rect_image(full_res, rect_full)
    r_mean, g_mean, b_mean = get_mean_rgb(roi)
    return full_res, r_mean, g_mean, b_mean, key

def select_mask_roi(image_path):
    """Let the user draw an ROI and return it in full-resolution coordinates.

    Returns:
        (rect, key): rect is [x, y, w, h] scaled to the original image, or
        None when no rectangle was accepted.
    """
    full_res = read_image(image_path)
    preview = resize_image(full_res, resize_height=800, resize_width=None)
    rect, key = get_image_roi(preview, image_path)
    if rect is not None:
        rect = scale_rect(rect, np.shape(preview), np.shape(full_res))
    return rect, key

def show_image(title, image):
    """Display `image` in a blocking matplotlib window titled `title`."""
    plt.imshow(image)
    plt.title(title)
    plt.axis('on')
    plt.show()
 
def cv_show_image(title, image):
    """Show an RGB or single-channel image in an OpenCV window and wait for a key.

    OpenCV displays BGR, so 3-channel images are converted from RGB first.
    """
    # Check the rank first: for a 2-D grayscale image shape[-1] is the width,
    # so a 3-pixel-wide gray image used to be passed to cvtColor as RGB.
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.imshow(title, image)
    cv2.waitKey(0)
 
def read_image(filename, resize_height=None, resize_width=None, normalization=False):
    """Read an image file as an RGB numpy array.

    The file is read with flag -1 (cv2.IMREAD_UNCHANGED), so the stored bit
    depth is kept (e.g. a 16-bit PNG stays uint16). Grayscale images are
    expanded to three channels.

    Args:
        filename: path of the image.
        resize_height, resize_width: optional target size forwarded to
            resize_image (None for both keeps the original size).
        normalization: if True, scale integer images to [0, 1] by the maximum
            of their dtype (255 for uint8, 65535 for uint16); other dtypes are
            divided by 255 as before.

    Returns:
        RGB image array.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    bgr_image = cv2.imread(filename, -1)
    if bgr_image is None:
        # Raise explicitly: an `assert` disappears under `python -O`.
        raise IOError("image open error: {}".format(filename))
    if bgr_image.ndim == 2:
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    rgb_image = resize_image(rgb_image, resize_height, resize_width)
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        if np.issubdtype(rgb_image.dtype, np.integer):
            rgb_image = rgb_image / float(np.iinfo(rgb_image.dtype).max)
        else:
            rgb_image = rgb_image / 255.0
    return rgb_image

def resize_image(image, resize_height, resize_width):
    """Resize `image`, deriving a missing dimension from the aspect ratio.

    Returns `image` itself when both target sizes are None. Derived sizes
    are truncated to int.
    """
    dims = np.shape(image)
    src_h, src_w = dims[0], dims[1]
    if resize_height is None and resize_width is None:
        return image
    if resize_height is None:
        resize_height = int(src_h * resize_width / src_w)
    elif resize_width is None:
        resize_width = int(src_w * resize_height / src_h)
    return cv2.resize(image, dsize=(resize_width, resize_height))

def scale_image(image, scale):
    """Resize `image` by the factors ``scale = (fx, fy)`` along x and y."""
    fx, fy = scale[0], scale[1]
    return cv2.resize(image, dsize=None, fx=fx, fy=fy)
 
def get_rect_image(image, rect):
    """Return the region of `image` covered by ``rect = [x, y, w, h]`` (a slice view)."""
    left, top, w, h = rect
    rows = slice(top, top + h)
    cols = slice(left, left + w)
    return image[rows, cols]

def scale_rect(orig_rect, orig_shape, dest_shape):
    """Map ``[x, y, w, h]`` from an image of `orig_shape` to one of `dest_shape`.

    x and w are scaled by the width ratio (shape[1]); y and h by the height
    ratio (shape[0]). Results are truncated to int.
    """
    # Which shape axis scales each rect component: x, y, w, h.
    axes = (1, 0, 1, 0)
    return [int(orig_rect[i] * dest_shape[ax] / orig_shape[ax])
            for i, ax in enumerate(axes)]
 
def show_image_rect(win_name, image, rect):
    """Draw ``rect = [x, y, w, h]`` onto `image` and show it with cv_show_image.

    The rectangle (colour tuple (0, 0, 255), 2 px) is drawn in place, so the
    caller's array is modified.
    """
    left, top, w, h = rect
    cv2.rectangle(image, (left, top), (left + w, top + h), (0, 0, 255), thickness=2)
    cv_show_image(win_name, image)
 
def rgb_to_gray(image):
    """Convert an RGB image to a single-channel grayscale image."""
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
 
def save_image(image_path, rgb_image, toUINT8=True):
    """Write an RGB (or grayscale) image to `image_path` with OpenCV.

    Args:
        image_path: destination file; OpenCV picks the format from the extension.
        rgb_image: RGB or 2-D grayscale image.
        toUINT8: if True, the image is multiplied by 255 and stored as uint8
            (presumably it holds floats in [0, 1] — confirm for callers).
    """
    if toUINT8:
        # Clip before casting: converting out-of-range floats to uint8 wraps
        # or is undefined, so e.g. 1.01 would turn into a dark pixel.
        rgb_image = np.asanyarray(np.clip(rgb_image * 255, 0, 255), dtype=np.uint8)
    if len(rgb_image.shape) == 2:
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_GRAY2BGR)
    else:
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(image_path, bgr_image)
 
def combime_save_image(orig_image, dest_image, out_dir, name, prefix):
    """Save `dest_image` alone and side by side with `orig_image`.

    Writes ``<out_dir>/<name>_<prefix>.jpg`` and
    ``<out_dir>/<name>_src_<prefix>.jpg`` through save_image with its default
    toUINT8=True. Both images therefore go through the same x255 scaling.
    NOTE(review): confirm both inputs are on the same 0..1 scale.
    """
    single_path = os.path.join(out_dir, name + "_" + prefix + ".jpg")
    pair_path = os.path.join(out_dir, "{}_src_{}.jpg".format(name, prefix))
    save_image(single_path, dest_image)
    save_image(pair_path, np.hstack((orig_image, dest_image)))

def get_mean_rgb(image):
    """Return the means of channels 0, 1 and 2 of an (H, W, C) image as (R, G, B)."""
    channel_means = [np.mean(image[:, :, c]) for c in (0, 1, 2)]
    return tuple(channel_means)
 
def get_illuminant(orig_image, R_mean, G_mean, B_mean, save_path, image_name):
    """Compute white-balance gains from ROI means; optionally save the corrected image.

    The gains normalise green to 1: ``Lr = G/R``, ``Lg = 1``, ``Lb = G/B``.
    Multiplying R and B by them makes the ROI neutral.

    If `save_path` is truthy, the RGB `orig_image` is corrected with these
    gains, clipped to its full range and written (as BGR) to
    ``os.path.join(save_path, image_name)``. Integer images keep their dtype,
    so 16-bit inputs are written as 16-bit instead of being scaled by 255 and
    saturated. Non-integer images are treated as 0..255 and written as uint8,
    as before.

    Returns:
        (Lr, Lg, Lb)
    """
    Lg = 1
    Lb = Lg * G_mean / B_mean
    Lr = Lg * G_mean / R_mean

    # The corrected image is only needed for saving, so skip the work otherwise.
    if save_path:
        if np.issubdtype(orig_image.dtype, np.integer):
            full_scale = float(np.iinfo(orig_image.dtype).max)
            out_dtype = orig_image.dtype
        else:
            full_scale, out_dtype = 255.0, np.uint8
        balanced = orig_image.astype(np.float32) / full_scale
        balanced[:, :, 0] *= Lr
        balanced[:, :, 2] *= Lb
        balanced = np.clip(balanced, 0.0, 1.0)
        balanced = (full_scale * balanced).astype(out_dtype)
        out_path = os.path.join(save_path, image_name)
        cv2.imwrite(out_path, cv2.cvtColor(balanced, cv2.COLOR_RGB2BGR))
        print("save image_awb in:", out_path)

    return Lr, Lg, Lb

def save_illuminant(args, image_path, Lr, Lg, Lb, R_mean, G_mean, B_mean):
    """Record the illuminant label of `image_path` in ``args.save_label``.

    Each label line is ``"<image_path> <1/Lr> <1/Lg> <1/Lb> \\n"`` (space
    separated; the trailing space is kept for compatibility with existing
    files). An existing entry for the same image is removed and the new one
    is appended at the end. R_mean, G_mean and B_mean are accepted for
    interface compatibility but are not written.

    Does nothing when ``args.save_label`` is empty or None.
    """
    label_path = args.save_label
    if not label_path:
        return

    with open(label_path, "r") as f_label:
        lines = f_label.readlines()

    # Match the path as the whole first field. A plain substring test would
    # also hit entries such as "<image_path>.bak".
    kept = []
    replaced = False
    for line in lines:
        if line.rstrip("\n") == image_path or line.startswith(image_path + " "):
            replaced = True
        else:
            kept.append(line)
    if replaced:
        print("rewrite image label", image_path)

    # A hand-edited file may lack a final newline; without this the new entry
    # would be glued onto the last existing line.
    if kept and not kept[-1].endswith("\n"):
        kept[-1] += "\n"
    kept.append(" ".join([image_path, str(1 / Lr), str(1 / Lg), str(1 / Lb)]) + " \n")

    # Write to a temporary file, then swap it in. An interruption mid-write
    # then cannot truncate the accumulated labels.
    tmp_path = label_path + ".tmp"
    with open(tmp_path, "w") as f_label:
        f_label.writelines(kept)
    os.replace(tmp_path, label_path)

def save_mask(args, image_path, orig_rect):
    """Record the ROI of `image_path` in ``args.save_label``.

    The label line is ``"<image_path> <x1> <y1> <x2> <y2>\\n"``. (x1, y1) is
    the top-left corner and (x2, y2) = (x + w, y + h) for
    ``orig_rect = [x, y, w, h]``. An existing entry for the same image is
    removed and the new one is appended at the end.

    Does nothing when ``args.save_label`` is empty or None.
    """
    x, y, w, h = orig_rect
    x1, y1, x2, y2 = x, y, x + w, y + h
    label_path = args.save_label
    if not label_path:
        return

    with open(label_path, "r") as f_label:
        lines = f_label.readlines()

    # Match the path as the whole first field, not as a substring.
    kept = [line for line in lines
            if not (line.rstrip("\n") == image_path
                    or line.startswith(image_path + " "))]

    # Keep the new entry on its own line even if the file lacks a final newline.
    if kept and not kept[-1].endswith("\n"):
        kept[-1] += "\n"
    kept.append(" ".join([image_path, str(x1), str(y1), str(x2), str(y2)]) + "\n")

    # Write-then-swap, so an interruption cannot truncate the existing labels.
    tmp_path = label_path + ".tmp"
    with open(tmp_path, "w") as f_label:
        f_label.writelines(kept)
    os.replace(tmp_path, label_path)

def _prepare_outputs(args):
    """Create the output directory and empty list/label files when missing.

    Options that were not given (None/empty) are skipped; os.path.exists(None)
    would otherwise raise TypeError.
    """
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    for path in (args.save_list, args.save_label):
        if path and not os.path.exists(path):
            with open(path, "w"):
                pass


def _list_image_names(read_path):
    """Return the names of the regular files in `read_path`, sorted.

    Sorting gives a stable order, so 'a'/'d' navigation and the printed
    current_index mean the same thing across runs. Sub-directories are skipped.

    Raises:
        ValueError: if the directory contains no files.
    """
    names = sorted(name for name in os.listdir(read_path)
                   if os.path.isfile(os.path.join(read_path, name)))
    if not names:
        raise ValueError("no image files found in {}".format(read_path))
    return names


def _delete_label(label_path, image_path):
    """Remove every line of `label_path` whose first field is `image_path`."""
    with open(label_path, "r") as f_label:
        lines = f_label.readlines()
    kept = [line for line in lines
            if not (line.rstrip("\n") == image_path
                    or line.startswith(image_path + " "))]
    if len(kept) != len(lines):
        print("delete label", image_path)
    tmp_path = label_path + ".tmp"
    with open(tmp_path, "w") as f_label:
        f_label.writelines(kept)
    os.replace(tmp_path, label_path)


def _record_illuminant(args, image_path, image_name, orig_image, R_mean, G_mean, B_mean):
    """Save the white-balanced image and the illuminant label for one accepted ROI.

    Only ROIs with 20 < G_mean < 230 are labelled (the thresholds are on an
    8-bit scale). Otherwise any earlier label for the image is removed.
    """
    Lr, Lg, Lb = get_illuminant(orig_image, R_mean, G_mean, B_mean,
                                args.save_path, image_name)
    if 20 < G_mean < 230:
        if args.save_label:
            save_illuminant(args, image_path, Lr, Lg, Lb, R_mean, G_mean, B_mean)
        else:
            print("Do not save label!")
    else:
        if args.save_label:
            _delete_label(args.save_label, image_path)
        print("G_mean abnormal", G_mean)


def main():
    """Run the interactive labelling loop configured by the global `parser`.

    -mode illuminant: for each accepted ROI, compute white-balance gains with
        get_illuminant (writing the corrected image to -save_path if given)
        and store them with save_illuminant.
    -mode mask: store the accepted ROI corners with save_mask.

    Keys returned by the ROI window: 'a' (97) previous image, 'd' (100) next
    image, Enter (13) / Space (32) quit. Once the last image has been handled
    the loop ends, unless 'a' moved back.

    Raises:
        ValueError: for an unknown mode or an empty input directory.
    """
    args = parser.parse_args()
    if args.mode not in ("illuminant", "mask"):
        raise ValueError("unknown mode: {}".format(args.mode))

    _prepare_outputs(args)
    image_name_list = _list_image_names(args.read_path)
    last_index = len(image_name_list) - 1
    current_index = 0

    while True:
        image_name = image_name_list[current_index]
        image_path = os.path.join(args.read_path, image_name)
        if args.mode == "illuminant":
            orig_image, R_mean, G_mean, B_mean, key = select_user_roi(image_path)
        else:
            orig_rect, key = select_mask_roi(image_path)

        if key == 97 and current_index > 0:  # 'a': previous image
            current_index -= 1

        if key in (13, 32):  # Enter / Space: stop labelling
            if args.mode == "illuminant":
                print("current_index:", current_index)
            break

        if args.mode == "illuminant":
            if orig_image is not None:
                _record_illuminant(args, image_path, image_name,
                                   orig_image, R_mean, G_mean, B_mean)
        elif orig_rect is not None:
            if args.save_label:
                save_mask(args, image_path, orig_rect)
            else:
                print("Do not save label!")

        # Compare against the last index (len - 1). Comparing against len let
        # 'd' on the last image step past the end and raise IndexError.
        if current_index < last_index:
            if key == 100:  # 'd': next image
                current_index += 1
        else:
            print("label illuminant complete!")
            break

 
if __name__ == '__main__':
    # `parser` is intentionally module-level: main() reads it as a global.
    parser = argparse.ArgumentParser(
        description='Interactive white-balance (illuminant) / ROI mask labelling tool')
    # `choices` makes argparse reject unknown modes with a usage message.
    parser.add_argument('-mode', choices=['illuminant', 'mask'],
                        help='get illuminant or mask', required=True)
    parser.add_argument('-read_path', help='path to read image', required=True)
    parser.add_argument('-save_path', help='path to save image', required=False)
    parser.add_argument('-save_label', help='path to save label', required=False)
    parser.add_argument('-save_list', help='path to save image list', required=False)
    main()
    # Example:
    # python get_illuminant_label_v2.py -mode illuminant -read_path /data/xywang/dataset/ISP_Dataset/AWB_wxy/20210507/PNG/ -save_path /data/xywang/dataset/ISP_Dataset/AWB_wxy/20210507/PNG_GT/ -save_label /data/xywang/dataset/ISP_Dataset/AWB_wxy/20210507/label.txt -save_list all_test.txt

RAW 生成、读取,以及 RAW 与 RGB 图互转方法

import numpy as np
import cv2

def demosaic_opencv():
    """Demosaic lena_rggb.png with OpenCV and display the colour result.

    The file is read unchanged (-1), so a single-channel mosaic stays
    single-channel. The OpenCV cvtColor docs name Bayer codes after the
    second row's 2nd and 3rd pixels. For a true RGGB mosaic that makes
    COLOR_BAYER_BG2BGR the matching code; verify against the image.
    """
    mosaic = cv2.imread("lena_rggb.png", -1)
    print(mosaic.shape)
    bgr = cv2.cvtColor(mosaic, cv2.COLOR_BAYER_BG2BGR)
    print(bgr.shape)
    cv2.imshow("rgb_image.png", bgr)
    cv2.waitKey(0)

def fun1():
    """Visualise an RGGB mosaic in colour and display it.

    lena_rggb.png is read as a 3-channel BGR image. At every pixel, all
    channels except the one the RGGB pattern samples there are zeroed.
    """
    img = cv2.imread("lena_rggb.png")
    # (row offset, col offset) -> BGR channels to clear (B=0, G=1, R=2).
    # Pattern assumed RGGB: R at (0,0), G at (0,1)/(1,0), B at (1,1).
    clear_plan = (
        ((0, 0), [0, 1]),  # R site: drop B, G
        ((0, 1), [0, 2]),  # G site: drop B, R
        ((1, 0), [0, 2]),  # G site: drop B, R
        ((1, 1), [1, 2]),  # B site: drop G, R
    )
    for (dy, dx), channels in clear_plan:
        img[dy::2, dx::2, channels] = 0

    cv2.imshow('lena_rgb.png', img)
    cv2.waitKey(0)

def fun2():
    """Turn lena_rggb.png (an RGGB mosaic) into a half-resolution colour image and show both.

    Each 2x2 RGGB cell becomes one BGR pixel: R from (0, 0), B from (1, 1),
    G the floor of the mean of (1, 0) and (0, 1). NOTE(review): the uint8
    casts assume an 8-bit mosaic; 16-bit values would wrap.

    Raises:
        IOError: if the file cannot be read.
    """
    raw_image = cv2.imread("lena_rggb.png", -1)
    if raw_image is None:
        raise IOError("cannot read lena_rggb.png")
    raw_image = raw_image.astype(np.float32)
    rgb_row = raw_image.shape[0] // 2
    rgb_col = raw_image.shape[1] // 2

    # Vectorised replacement for the per-pixel Python double loop. Trimming to
    # whole 2x2 cells gives every plane the shape (rgb_row, rgb_col).
    cells = raw_image[:2 * rgb_row, :2 * rgb_col]
    red = cells[0::2, 0::2]
    green = (cells[1::2, 0::2] + cells[0::2, 1::2]) // 2
    blue = cells[1::2, 1::2]
    rgb_image = np.dstack((blue, green, red)).astype(np.uint8)  # OpenCV BGR order

    raw_image = raw_image.astype(np.uint8)
    print(rgb_image.shape)
    cv2.imshow('raw_image.png', raw_image)
    cv2.imshow('lena_rgb.png', rgb_image)
    cv2.waitKey(0)

def rgb2raw(image_path='PNG/vlcsnap-2021-04-20-15h28m27s289.png', image=None):
    """Sample a BGR image into a single-channel RGGB Bayer mosaic.

    The image is cropped so both sides are multiples of 4 (kept from the
    original code; the 2x2 pattern itself only needs multiples of 2).

    Args:
        image_path: file read with cv2.IMREAD_COLOR when `image` is None.
        image: optional (H, W, 3) BGR array to use instead of reading a file.
            It is not modified.

    Returns:
        (H', W') mosaic with the input's dtype: R at (even, even), G at
        (even, odd) and (odd, even), B at (odd, odd). The original version
        computed this and then discarded it.

    Raises:
        IOError: if `image` is None and the file cannot be read.
    """
    if image is None:
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise IOError("cannot read {}".format(image_path))
    print(image.shape)
    row, col = image.shape[:2]
    img = image[:row - row % 4, :col - col % 4, :]
    row, col = img.shape[:2]

    # Only the sampled channel is read at each site, so the channels need not
    # be zeroed first. BGR channel order: B=0, G=1, R=2.
    raw = np.zeros((row, col), dtype=img.dtype)
    raw[0::2, 0::2] = img[0::2, 0::2, 2]  # red
    raw[0::2, 1::2] = img[0::2, 1::2, 1]  # green
    raw[1::2, 0::2] = img[1::2, 0::2, 1]  # green
    raw[1::2, 1::2] = img[1::2, 1::2, 0]  # blue
    return raw

def show_raw(raw_path='images/indoor_1920_1080_rggb_16bit.raw', width=1920, height=1080):
    """Load a headerless 16-bit raw frame and display it.

    The file is read as native-endian uint16 and reshaped to
    (height, width, 1). OpenCV's imshow divides 16-bit pixel values by 256
    for display.

    Args:
        raw_path: path of the raw file.
        width, height: frame size in pixels (defaults match the sample file).
    """
    raw = np.fromfile(raw_path, dtype=np.uint16)
    raw = raw.reshape(height, width, 1)
    print(raw.shape)

    cv2.imshow("rgb_image.png", raw)
    cv2.waitKey(0)

if __name__ == "__main__":
    #fun1()
    #fun2()
    #rgb2raw()
    show_raw()
    #demosaic_opencv()

Gamma 校正

def GAMMA(im, gamma=2.4):
    """Apply the sRGB-style transfer curve (linear -> display) to an image.

    Values are scaled to [0, 1] and clipped, then mapped with ``12.92 * x``
    below 0.0031308 and ``1.055 * x ** (1 / gamma) - 0.055`` otherwise.
    The result is clipped again and scaled back.

    Args:
        im: integer image (scaled by its dtype's maximum), or float image
            assumed to lie in [0, 1].
        gamma: exponent of the power segment (2.4 for sRGB).

    Returns:
        Array of the same shape. Integer inputs keep their dtype; previously
        the output was always uint8, which wrapped 16-bit results. Float
        inputs give a float64 result.
    """
    im = np.asarray(im)
    is_int = np.issubdtype(im.dtype, np.integer)
    full_scale = np.iinfo(im.dtype).max if is_int else 1.0
    a = 0.055
    b = 1.055
    t = 0.0031308
    # np.float was removed in NumPy 1.24; it was an alias of float64.
    x = np.clip(im.astype(np.float64) / full_scale, 0.0, 1.0)
    y = np.where(x < t, x * 12.92, b * np.power(x, 1 / gamma) - a)
    y = np.clip(y, 0.0, 1.0)
    if not is_int:
        return y
    # Round instead of truncating: 1.055 - 0.055 evaluates to just below 1.0,
    # so truncation mapped full white to max - 1.
    return np.rint(y * full_scale).astype(im.dtype)

运动模糊

def motion_blur(image, base_degree=10, base_angle=10):
    """Apply a linear motion blur to an image.

    A ``base_degree`` x ``base_degree`` kernel holding a diagonal line is
    rotated by ``base_angle`` degrees, normalised to unit sum and convolved
    with the image (cv2.filter2D, same depth). As before, the result is then
    min-max stretched to 0..255 and returned as uint8.

    Raises:
        ValueError: if base_degree < 1.
    """
    degree = int(base_degree)
    angle = base_angle
    if degree < 1:
        raise ValueError("base_degree must be >= 1")
    image = np.array(image)
    # Rotate about the kernel's true centre, (size - 1) / 2. Using size / 2
    # shifts the line by half a pixel and so offsets the blurred image.
    centre = (degree - 1) / 2.0
    M = cv2.getRotationMatrix2D((centre, centre), angle, 1)
    motion_blur_kernel = cv2.warpAffine(np.diag(np.ones(degree)), M, (degree, degree))
    # Rotation clips the line's ends and interpolation spreads it, so the sum
    # is generally no longer `degree`. Divide by the actual sum so the filter
    # preserves brightness.
    total = motion_blur_kernel.sum()
    motion_blur_kernel = motion_blur_kernel / (total if total > 0 else degree)
    blurred = cv2.filter2D(image, -1, motion_blur_kernel)
    cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX)
    blurred = np.array(blurred, dtype=np.uint8)
    return blurred

你可能感兴趣的:(AI, ISP, python, opencv, pytorch)