Paste defect crops from large NG (defect) images onto normal (OK) images to generate additional defect training data. The first script, augmentImage.py, provides the augmentation primitives; the second script selects defect crops from the NG set, pastes them into OK images, and writes matching VOC XML annotations.
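
The core idea, as a minimal sketch (the arrays and the bounding box below are placeholders for illustration; the real pipeline reads the box from the NG image's XML and restricts the paste position to the labelled LCD region):

import numpy as np
from random import randint

# placeholder stand-ins: in the real pipeline these come from cv2.imread
ng_img = np.zeros((600, 800, 3), dtype=np.uint8)   # NG (defect) image
ok_img = np.zeros((600, 800, 3), dtype=np.uint8)   # OK (normal) image
xmin, ymin, xmax, ymax = 120, 80, 180, 140         # defect box, normally read from the XML

defect = ng_img[ymin:ymax, xmin:xmax]              # crop the defect patch
h, w = defect.shape[:2]

# choose a random top-left corner so the patch stays inside the OK image, then paste it
ok_h, ok_w = ok_img.shape[:2]
x0, y0 = randint(0, ok_w - w), randint(0, ok_h - h)
ok_img[y0:y0 + h, x0:x0 + w] = defect

# (x0, y0, x0 + w, y0 + h) is the bounding box to write into the new sample's XML

The full scripts follow: the augmentation module (augmentImage.py) first, then the pasting pipeline.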

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from skimage import transform
from skimage.util import random_noise
from random import choice, randint
import random
import cv2
import numpy as np

def zoom(img):
    h, w, _ = img.shape

    # pick two distinct x and y coordinates and order them so the crop is never empty
    p1x, p2x = sorted(random.sample(range(w), 2))
    p1y, p2y = sorted(random.sample(range(h), 2))

    crop_p1x = max(p1x, 0)
    crop_p1y = max(p1y, 0)
    crop_p2x = min(p2x, w)
    crop_p2y = min(p2y, h)

    cropped_img = img[crop_p1y:crop_p2y, crop_p1x:crop_p2x]

    # pad if the selected region extends beyond the image borders
    x_pad_before = -min(0, p1x)
    x_pad_after = max(0, p2x - w)
    y_pad_before = -min(0, p1y)
    y_pad_after = max(0, p2y - h)

    padding = [(y_pad_before, y_pad_after), (x_pad_before, x_pad_after)]
    is_colour = len(img.shape) == 3
    if is_colour:
        padding.append((0, 0))  # colour images have an extra channel dimension

    padded_img = np.pad(cropped_img, padding, 'constant')
    # preserve_range keeps the uint8 value range instead of rescaling to [0, 1]
    img = transform.resize(padded_img, (h, w), preserve_range=True).astype(np.uint8)
    return img

def rotate(image, angle_range=(-90, 90), scale=1.0):
    # randomly pick a rotation angle
    angle = np.random.uniform(low=angle_range[0], high=angle_range[1])
    # image size
    height, width = image.shape[:2]
    # rotation centre
    center = (width / 2, height / 2)
    # build the rotation matrix
    rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale)
    # apply the rotation
    rotated_image = cv2.warpAffine(image, rotation_matrix, (width, height))

    return rotated_image

def hflip(image):
    img = cv2.flip(image, 1)  # horizontal flip
    return img

def vflip(image):
    img = cv2.flip(image, 0)  # vertical flip
    return img

def gaussianblur(image):
    img = cv2.GaussianBlur(image, (3, 3), 0)
    return img

def noise(image, noise_type='gaussian', seed=None):
    if seed is not None:
        np.random.seed(seed)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    if noise_type == 'gaussian':
        # adjust the mean and standard deviation as needed
        gauss = np.random.normal(0, 20, image.shape)
        noisy_image = np.clip(image.astype(np.float32) + gauss, 0, 255).astype(np.uint8)
    elif noise_type == 'salt_and_pepper':
        noisy_image = np.copy(image)
        h, w = image.shape[:2]
        num_pixels = int(0.02 * h * w)  # adjust the number of noisy pixels as needed
        coords = tuple(np.random.randint(0, i, num_pixels) for i in (h, w))
        noisy_image[coords] = 255
        coords = tuple(np.random.randint(0, i, num_pixels) for i in (h, w))
        noisy_image[coords] = 0
    else:
        raise ValueError('Invalid noise type')
    noisy_image_3ch = cv2.cvtColor(noisy_image, cv2.COLOR_GRAY2BGR)

    return noisy_image_3ch

def hue_image(image, saturation=50):
    # brighten the V (value) channel in HSV space by `saturation`, saturating at 255
    image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    v = image[:, :, 2]
    v = np.where(v <= 255 - saturation, v + saturation, 255)
    image[:, :, 2] = v
    img = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
    return img

def add_light(image, gamma=1.0):
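    # gamma correction via a lookup table: gamma > 1 brightens, gamma < 1 darkens, gamma == 1 leaves the image unchanged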
    invGamma = 1.0 / gamma
    table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
    if image.dtype != np.uint8:
        image = cv2.convertScaleAbs(image)
    if len(image.shape) == 2:
        return cv2.LUT(image, table)
    elif len(image.shape) == 3:
        channels = cv2.split(image)
        result_channels = []
        for channel in channels:
            result_channel = cv2.LUT(channel, table)
            result_channels.append(result_channel)
        return cv2.merge(result_channels)

def multiply_image(image, R=1.25, G=1.25, B=1.25):
    # OpenCV images are BGR, so order the scale factors accordingly and clip back to uint8
    img = np.clip(image.astype(np.float32) * [B, G, R], 0, 255).astype(np.uint8)
    return img

def opening_image(image,shift=5):
    kernel = np.ones((shift, shift), np.uint8)
    img = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel)
    return img

def closing_image(image, shift=5):
    kernel = np.ones((shift, shift), np.uint8)
    img = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel)
    return img

def transformation_image(image):
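    # affine warp defined by three hard-coded source/destination control points (pixel coordinates)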
    rows, cols, ch = image.shape
    pts1 = np.float32([[50, 50], [200, 50], [50, 200]])
    pts2 = np.float32([[10, 100], [200, 50], [100, 250]])
    M = cv2.getAffineTransform(pts1, pts2)
    img = cv2.warpAffine(image, M, (cols, rows))
    return img

def resize(image):
    h, w, _ = image.shape
    random_h = choice(range(int(0.7*h), h))
    random_w = choice(range(int(0.7 * w), w))
    img = cv2.resize(image, (random_w, random_h))
    return img
The second script is the main pipeline; it imports the augmentation functions above from augmentImage.py:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import os
import re
import cv2
import shutil
import random
import numpy as np
from tqdm import tqdm
from imutils import paths
from xml.dom.minidom import parse
from random import choice, randint, sample
from augmentImage import rotate, hflip, vflip, gaussianblur, noise, hue_image, add_light, multiply_image, opening_image, closing_image, transformation_image
from skimage.util import random_noise
import xml.etree.ElementTree as ET
import time

def random_position(ng_pixel, ok_img, ok_x, ok_y):
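    # Centre the defect patch on (ok_x, ok_y), clamp the box to the OK image bounds,
    # and shrink the patch if the clamped region is smaller than the patch itself.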
    ng_h, ng_w, _ = ng_pixel.shape
    x_half = ng_w // 2
    y_half = ng_h // 2
    ok_height, ok_width, _ = ok_img.shape
    ok_xmin, ok_ymin, ok_xmax, ok_ymax = max(0, ok_x - x_half), max(0, ok_y - y_half), min(ok_x + x_half, ok_width), min(ok_height, ok_y + y_half)

    ng_height, ng_width, _ = ng_pixel.shape  # shape is (height, width, channels)
    ok_pixel_h, ok_pixel_w = ok_ymax - ok_ymin, ok_xmax - ok_xmin
    if ng_width > ok_pixel_w or ng_height > ok_pixel_h:
        ng_pixel = cv2.resize(ng_pixel, (ok_pixel_w, ok_pixel_h))
    return ok_xmin, ok_ymin, ok_xmax, ok_ymax, ng_pixel


def ng_img_and_label(img_file, class_label, img_suffix='.bmp'):
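    # Read an NG image and its VOC XML; return (defect crop, label) for the first
    # object whose label is in class_label, or None if no such object exists.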
    img = cv2.imread(img_file)
    xml_file = img_file.replace(img_suffix, '.xml')
    big_dom_tree = parse(xml_file)
    big_root_node = big_dom_tree.documentElement

    objects = big_root_node.getElementsByTagName('object')
    for obj in objects:
        label = obj.getElementsByTagName('name')[0].childNodes[0].nodeValue
        if label in class_label:
            xmin = int(obj.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
            ymin = int(obj.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
            xmax = int(obj.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
            ymax = int(obj.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
            select_img = img[ymin:ymax, xmin:xmax]
            return (select_img, label)
    return None

def get_random_ng_img(ng_file, ng_list, num, class_label, img_suffix='.bmp'):
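    # Sample (num - 1) other NG files plus ng_file itself and collect their (defect crop, label) pairs.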
    ng_hw_list = []
    # random.sample requires a sequence (sampling from a set was removed in Python 3.11)
    available_items = list(set(ng_list) - {ng_file})
    get_select_ng = random.sample(available_items, num - 1) + [ng_file]
    for ng_file in get_select_ng:
        ret_ng_label = ng_img_and_label(ng_file, class_label, img_suffix)
        if ret_ng_label is not None:
            ng_hw_list.append(ret_ng_label)
    return ng_hw_list

def get_lcd_loc(xml_file, in_label='lcd'):
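    # Return [xmin, ymin, xmax, ymax] of the object labelled in_label (the LCD panel region), or [] if absent.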
    big_dom_tree = parse(xml_file)
    big_root_node = big_dom_tree.documentElement

    objects = big_root_node.getElementsByTagName('object')
    for obj in objects:
        label = obj.getElementsByTagName('name')[0].childNodes[0].nodeValue
        if label == in_label:
            xmin = int(obj.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
            ymin = int(obj.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
            xmax = int(obj.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
            ymax = int(obj.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
            return [xmin, ymin, xmax, ymax]
    return []

def random_select_ok_loc(lcd_loc, x_gap=100, y_gap=50):
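    # Pick a random paste centre inside the LCD region, at least x_gap/y_gap pixels from its border
    # (assumes the region is wider than 2*x_gap and taller than 2*y_gap).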
    lcd_width = lcd_loc[2] - lcd_loc[0]
    lcd_height = lcd_loc[3] - lcd_loc[1]
    x = randint(x_gap, lcd_width-x_gap) + lcd_loc[0]
    y = randint(y_gap, lcd_height-y_gap) + lcd_loc[1]
    return x, y

def set_position(ng_pixel, lcd_loc, select_center_x, select_center_y):
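    # Build the paste box centred on the selected point, clamp it to the LCD region,
    # and resize the defect patch if the clamped box is smaller than the patch.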
    ng_h, ng_w, _ = ng_pixel.shape
    ng_x_half = ng_w // 2
    ng_y_half = ng_h // 2
    ng_xmin = select_center_x - ng_x_half
    ng_ymin = select_center_y - ng_y_half
    ng_xmax = select_center_x + ng_x_half
    ng_ymax = select_center_y + ng_y_half

    ok_xmin, ok_ymin, ok_xmax, ok_ymax = max(lcd_loc[0], ng_xmin), max(lcd_loc[1], ng_ymin), min(lcd_loc[2], ng_xmax), min(lcd_loc[3], ng_ymax)

    ok_h, ok_w = ok_ymax - ok_ymin, ok_xmax - ok_xmin

    if ng_w > ok_w or ng_h > ok_h:
        ng_pixel = cv2.resize(ng_pixel, (ok_w, ok_h))

    return ok_xmin, ok_ymin, ok_xmax, ok_ymax, ng_pixel

def replace_pixe(ok_img, ok_loc, ng_pixel):
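    # Paste the defect patch into ok_img over the region ok_loc = [xmin, ymin, xmax, ymax],
    # resizing the patch first if its shape does not match the region.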
    ok_height, ok_width, _ = ok_img.shape
    ng_height, ng_width, _ = ng_pixel.shape
    replace_h, replace_w = (ok_loc[3] - ok_loc[1]), (ok_loc[2] - ok_loc[0])

    if ng_height != replace_h or ng_width != replace_w:
        ng_pixel = cv2.resize(ng_pixel, (replace_w, replace_h))

    ok_img[ok_loc[1]:ok_loc[3], ok_loc[0]:ok_loc[2]] = ng_pixel
    return ok_img

def save_image_and_xml(ok_image, label_loc, save_dir, ok_file_name, ok_h, ok_w, img_suffix='.bmp'):
    # Create the save directory if it doesn't exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Create an ElementTree for the XML structure
    annotation = ET.Element('annotation')

    folder = ET.SubElement(annotation, 'folder')
    folder.text = '1'

    filename = ET.SubElement(annotation, 'filename')
    filename.text = ok_file_name

    path = ET.SubElement(annotation, 'path')
    path.text = os.path.join(save_dir, ok_file_name)

    source = ET.SubElement(annotation, 'source')
    database = ET.SubElement(source, 'database')
    database.text = 'Unknown'

    size = ET.SubElement(annotation, 'size')
    width = ET.SubElement(size, 'width')
    width.text = str(ok_w)
    height = ET.SubElement(size, 'height')
    height.text = str(ok_h)
    depth = ET.SubElement(size, 'depth')
    depth.text = '3'

    segmented = ET.SubElement(annotation, 'segmented')
    segmented.text = '0'

    for label_info in label_loc:
        object_name = label_info[0]
        xmin, ymin, xmax, ymax = label_info[1:]

        obj = ET.SubElement(annotation, 'object')
        name = ET.SubElement(obj, 'name')
        name.text = object_name
        pose = ET.SubElement(obj, 'pose')
        pose.text = 'Unspecified'
        truncated = ET.SubElement(obj, 'truncated')
        truncated.text = '0'
        difficult = ET.SubElement(obj, 'difficult')
        difficult.text = '0'
        bndbox = ET.SubElement(obj, 'bndbox')
        ET.SubElement(bndbox, 'xmin').text = str(xmin)
        ET.SubElement(bndbox, 'ymin').text = str(ymin)
        ET.SubElement(bndbox, 'xmax').text = str(xmax)
        ET.SubElement(bndbox, 'ymax').text = str(ymax)

    # Create the XML tree and save it to a file
    tree = ET.ElementTree(annotation)

    current_time = time.time()

    # Timestamp in the format YYYYMMDDHHMMSSmmm; time.strftime does not support %f,
    # so the millisecond part is appended manually.
    formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(current_time)) + f"{int(current_time * 1000) % 1000:03d}"

    img_path = os.path.join(save_dir, f'NG{formatted_time}C' + ok_file_name)
    xml_path = img_path.replace(img_suffix, ".xml")

    tree.write(xml_path, encoding='utf-8', xml_declaration=True)

    cv2.imwrite(img_path, ok_image)

def ng_random_augment(img_list, normal_list, save_root, img_suffix='.bmp'):
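    # For each (defect crop, label): apply one random augmentation, pick a paste position inside
    # the LCD region of a randomly chosen OK image, paste the crop, and record its new box;
    # finally write the composed image together with a VOC XML annotation.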
    single_fuc_total = ['rotate', 'gaussianblur', 'noise', 'hue_image', 'add_light',
                         'multiply_image', 'opening_image', 'transformation_image']

    ok_path = choice(normal_list)
    ok_image = cv2.imread(ok_path)
    ok_file_name = os.path.split(ok_path)[1]
    ok_h, ok_w, _ = ok_image.shape
    lcd_loc = get_lcd_loc(ok_path.replace(img_suffix, '.xml'))
    if len(lcd_loc) == 0:
        print(f'error!!! {ok_path} xml not lcd node')
        return

    ng_point = []
    # apply one random augmentation to each defect crop and paste it
    for ng_image, in_label in img_list:
        in_ng_h, in_ng_w, _ = ng_image.shape
        func = choice(single_fuc_total)
        ret_image = globals()[func](ng_image)  # look up the augmentation function by name instead of using eval

        select_x, select_y = random_select_ok_loc(lcd_loc)

        xmin, ymin, xmax, ymax, ng_img = set_position(ret_image, lcd_loc, select_x, select_y)
        ng_point.append([in_label, xmin, ymin, xmax, ymax])
        ok_image = replace_pixe(ok_image, [xmin, ymin, xmax, ymax], ng_img)
    ng_point.append(['lcd', lcd_loc[0], lcd_loc[1], lcd_loc[2], lcd_loc[3]])
    save_image_and_xml(ok_image, ng_point, save_root, ok_file_name, ok_h, ok_w, img_suffix)

if __name__ == '__main__':
    # number of defects to paste onto each OK image
    num = 1
    # defect classes to paste
    class_label = ['caihongbian', 'louye', 'yisedian', 'secha', 'huahen']
    # directory of NG (defect) images and their XML annotations
    ng_root = r'/home/share_data/ext/PVDefectData/test2023/08/23/1/bigNG/'
    # directories of defect-free (OK) images
    ok_root = [
                r'/home/share_data/ext/PVDefectData/test2023/08/23/1/bigOK/'
               ]
    # output directory for the augmented images and XML files
    save_root = r'/home/share_data/ext/PVDefectData/test2023/08/23/1/save/'
    img_suffix = '.bmp'
    ng_list = list(paths.list_images(ng_root))
    ok_list = []
    for one_root in ok_root:
        ok_list += list(paths.list_images(one_root))

    for ng_file in tqdm(ng_list):
        ret_ng_img = get_random_ng_img(ng_file, ng_list, num, class_label, img_suffix)
        ng_random_augment(ret_ng_img, ok_list, save_root, img_suffix='.bmp')





