




class DataAugmentation:
    """Paired image / bounding-box augmentations for object detection.

    Every public method takes a PIL image and a box tensor and returns the
    (possibly transformed) ``(img, boxes)`` pair. Box rows are laid out as
    ``[label, x1, y1, x2, y2]``: column 0 is carried through untouched and
    columns 1-4 are corner coordinates in pixels. The resize / flip / crop
    methods keep only those first five columns.
    """

    # Converters between PIL images and float tensors (the pixel-level methods
    # below clamp their results to [0, 1] before converting back).
    to_tensor = transforms.ToTensor()
    to_image = transforms.ToPILImage()

    def __init__(self):
        super().__init__()
        # Handle on the ``transforms`` module used by the thin wrappers below.
        self.transforms = transforms

    # ------------------------------------------------------------------ geometry

    def resize(self, img, boxes, size):
        """Resize the image to a ``size`` x ``size`` square and rescale the boxes.

        The aspect ratio is not preserved: x and y are scaled independently.

        :param img: PIL image.
        :param boxes: box tensor, rows ``[label, x1, y1, x2, y2]``.
        :param size: target side length in pixels.
        :return: ``(resized_img, rescaled_boxes)``.
        """
        width, height = img.size
        scale_x = size / width
        scale_y = size / height
        label, coords = boxes[:, :1], boxes[:, 1:5]
        coords = coords * torch.Tensor([scale_x, scale_y, scale_x, scale_y])
        return img.resize((size, size), Image.BILINEAR), torch.cat((label, coords), dim=1)

    def resize_(self, img, boxes, size):
        """Resize the image to ``size = (width, height)`` and rescale the boxes.

        :param img: PIL image.
        :param boxes: box tensor, rows ``[label, x1, y1, x2, y2]``.
        :param size: ``(width, height)`` target in pixels.
        :return: ``(resized_img, rescaled_boxes)``.
        """
        width, height = img.size
        scale_x = size[0] / width
        scale_y = size[1] / height
        # Round the scaled dimensions to whole pixels for PIL.
        out_w = int(scale_x * width + 0.5)
        out_h = int(scale_y * height + 0.5)
        label, coords = boxes[:, :1], boxes[:, 1:5]
        coords = coords * torch.Tensor([scale_x, scale_y, scale_x, scale_y])
        return img.resize((out_w, out_h), Image.BILINEAR), torch.cat((label, coords), dim=1)

    def random_flip_horizon(self, img, boxes):
        """Horizontally flip the image and its boxes together with probability 0.5.

        :param img: PIL image.
        :param boxes: box tensor, rows ``[label, x1, y1, x2, y2]``; never modified in place.
        :return: ``(img, boxes)`` -- both flipped, or both returned unchanged.
        """
        p = torch.rand(1)
        if p > 0.5:
            # Flip unconditionally: the coin has already been tossed above and the
            # boxes below are always mirrored, so the image must always follow.
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            width = img.width
            label, coords = boxes[:, :1], boxes[:, 1:5]
            # Build fresh columns instead of assigning into ``coords``: slicing
            # yields a view, so in-place writes would corrupt the caller's tensor.
            new_x1 = width - coords[:, 2]
            new_x2 = width - coords[:, 0]
            coords = torch.stack((new_x1, coords[:, 1], new_x2, coords[:, 3]), dim=1)
            boxes = torch.cat((label, coords), dim=1)
        return img, boxes

    def random_flip_vertical(self, img, boxes):
        """Vertically flip the image and its boxes together with probability 0.5.

        :param img: PIL image.
        :param boxes: box tensor, rows ``[label, x1, y1, x2, y2]``; never modified in place.
        :return: ``(img, boxes)`` -- both flipped, or both returned unchanged.
        """
        p = torch.rand(1)
        if p > 0.5:
            # Flip unconditionally so the image always matches the mirrored boxes.
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
            height = img.height
            label, coords = boxes[:, :1], boxes[:, 1:5]
            # New tensor rather than in-place writes into a view of the input.
            new_y1 = height - coords[:, 3]
            new_y2 = height - coords[:, 1]
            coords = torch.stack((coords[:, 0], new_y1, coords[:, 2], new_y2), dim=1)
            boxes = torch.cat((label, coords), dim=1)
        return img, boxes

    def center_crop(self, img, boxes, size=(600, 600)):
        """Crop a ``size = (width, height)`` window from the centre of the image.

        Boxes are shifted into the crop's coordinate frame and clipped to
        ``[0, width - 1] x [0, height - 1]``. Boxes lying completely outside the
        crop are not removed; clipping collapses them to zero-area boxes.

        :param img: PIL image.
        :param boxes: box tensor, rows ``[label, x1, y1, x2, y2]``.
        :param size: crop size ``(width, height)`` in pixels.
        :return: ``(cropped_img, shifted_boxes)``.
        """
        width, height = img.size
        crop_w, crop_h = size
        max_corner = torch.as_tensor([crop_w - 1, crop_h - 1], dtype=torch.float32)
        top = int(round((height - crop_h) / 2.))
        left = int(round((width - crop_w) / 2.))
        img = img.crop((left, top, left + crop_w, top + crop_h))
        label, coords = boxes[:, :1], boxes[:, 1:5]
        coords = coords - torch.Tensor([left, top, left, top])
        # View every box as two (x, y) corners so a single min/clamp bounds both.
        coords = torch.min(coords.reshape(-1, 2, 2), max_corner).clamp(min=0).reshape(-1, 4)
        return img, torch.cat((label, coords), dim=1)

    # --------------------------------------------- photometric transform wrappers

    def random_equalize(self, img, boxes, p=0.5):
        """Equalize the image histogram with probability ``p``; boxes are unchanged.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param p: probability of the image being equalized.
        """
        transform = self.transforms.RandomEqualize(p=p)
        img = transform(img)
        return img, boxes

    def random_autocontrast(self, img, boxes, p=0.5):
        """Autocontrast the image with probability ``p``; boxes are unchanged.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param p: probability of the image being autocontrasted.
        """
        transform = self.transforms.RandomAutocontrast(p=p)
        img = transform(img)
        return img, boxes

    def random_adjustSharpness(self, img, boxes, sharpness_factor=1, p=0.5):
        """Adjust the image sharpness with probability ``p``; boxes are unchanged.

        NOTE(review): torchvision documents a sharpness factor of 1 as returning
        the original image, so the default makes this a no-op -- confirm intent.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param sharpness_factor: how much to adjust the sharpness.
        :param p: probability of the sharpness adjustment being applied.
        """
        transform = self.transforms.RandomAdjustSharpness(sharpness_factor=sharpness_factor, p=p)
        img = transform(img)
        return img, boxes

    def random_solarize(self, img, boxes, threshold=1, p=0.5):
        """Solarize the image with probability ``p``; boxes are unchanged.

        Solarizing inverts every pixel value at or above ``threshold``.
        NOTE(review): with the default ``threshold=1`` almost every pixel is
        inverted -- confirm the intended default.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param threshold: all pixels equal to or above this value are inverted.
        :param p: probability of the image being solarized.
        """
        transform = self.transforms.RandomSolarize(threshold=threshold, p=p)
        img = transform(img)
        return img, boxes

    def random_posterize(self, img, boxes, bits=0, p=0.5):
        """Posterize the image with probability ``p``; boxes are unchanged.

        Posterizing keeps only the top ``bits`` bits of every colour channel.
        NOTE(review): the default ``bits=0`` keeps no bits at all, which
        presumably blanks the image -- confirm the intended default.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param bits: number of bits to keep for each channel (0-8).
        :param p: probability of the image being posterized.
        """
        transform = self.transforms.RandomPosterize(bits=bits, p=p)
        img = transform(img)
        return img, boxes

    def random_grayscale(self, img, boxes, p=0.5):
        """Convert the image to grayscale with probability ``p`` (default 0.5).

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param p: probability of returning the grayscale version; the image is
            left unchanged with probability ``1 - p``.
        """
        transform = self.transforms.RandomGrayscale(p=p)
        img = transform(img)
        return img, boxes

    def gaussian_blur(self, img, boxes, kernel_size=5, sigma=(0.1, 2.0)):
        """Blur the image with a Gaussian kernel; boxes are unchanged.

        There is no probability gate here: the blur is always applied.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param kernel_size: size of the Gaussian kernel.
        :param sigma: ``(min, max)`` range for the kernel's standard deviation.
        """
        transform = self.transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
        img = transform(img)
        return img, boxes

    def random_invert(self, img, boxes, p=0.5):
        """Invert the image colours with probability ``p``; boxes are unchanged.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param p: probability of the image being color inverted.
        """
        transform = self.transforms.RandomInvert(p=p)
        img = transform(img)
        return img, boxes

    # --------------------------------------------------------- occlusion / warps

    def random_cutout_(self, img, boxes, p=0.5, scale=(0.02, 0.4), ratio=(0.4, 1 / 0.4), value=(0, 255),
                       pixel_level=False, inplace=False):
        """Randomly erase a region of the image via ``Cutout``; boxes are unchanged.

        :param img: image passed straight to ``Cutout``.
        :param boxes: box tensor, returned as-is.
        :param p: probability that the erasing operation is performed.
        :param scale: range of the erased area as a proportion of the image area.
        :param ratio: range of aspect ratios of the erased area.
        :param value: erasing value.
        :param pixel_level: whether to fill with a single number. Defaults to False.
        :param inplace: whether to erase in place. Defaults to False.
        """
        transform = Cutout(p=p, scale=scale, ratio=ratio, value=value, pixel_level=pixel_level, inplace=inplace)
        img = transform(img)
        return img, boxes

    def random_cutout(self, img, boxes):
        """Apply the ``cutout`` helper on a NumPy copy of the image and convert back to PIL."""
        img = np.array(img)
        img, boxes = cutout(img, boxes)
        img = Image.fromarray(img)
        return img, boxes

    def random_rotate(self, img, boxes, degrees=5, expand=False, center=None, fill=0, resample=None):
        """Rotate the image by a small random angle.

        An integer bound ``d`` is drawn from ``[0, degrees]``; the rotation
        transform then samples its actual angle from ``[-d, d]``.

        NOTE(review): the boxes are returned unchanged, so after rotation they
        only approximately match the image (acceptable for small angles only).

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param degrees: upper bound for the random angle bound ``d``.
        :param expand: forwarded to ``RandomRotation``.
        :param center: forwarded to ``RandomRotation``.
        :param fill: forwarded to ``RandomRotation``.
        :param resample: forwarded to ``RandomRotation`` only when not None.
        """
        bound = torch.randint(0, degrees + 1, (1,)).item()
        rotation_kwargs = dict(degrees=bound, expand=expand, center=center, fill=fill)
        # ``resample`` was deprecated and later removed from torchvision's
        # RandomRotation; forwarding it unconditionally breaks on new versions.
        if resample is not None:
            rotation_kwargs['resample'] = resample
        transform = self.transforms.RandomRotation(**rotation_kwargs)
        img = transform(img)
        return img, boxes

    def random_perspective(self, img, boxes, degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                           border=(0, 0)):
        """Random affine/perspective warp of image and boxes via the module-level helper.

        The image goes through NumPy and the boxes through ``.numpy()`` for the
        helper; results are converted back to PIL / torch.
        """
        img = np.array(img)
        # Module-level ``random_perspective`` (not this method) does the work.
        img, boxes = random_perspective(img, boxes.numpy(), degrees=degrees, translate=translate, scale=scale,
                                        shear=shear, perspective=perspective, border=border)
        img = Image.fromarray(img)
        return img, torch.from_numpy(boxes)

    def random_erasing(self, img, boxes, count=3, scale=0.01, ratio=0.4, value=0, inplace=False):
        """Run ``RandomErasing`` over the image several times; boxes are unchanged.

        ``count`` passes are made (``count == 0`` still makes one pass; a
        negative count makes none). Each pass fires with ``RandomErasing``'s own
        default probability.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param count: number of erasing passes.
        :param scale: erased area as a fixed proportion of the image area.
        :param ratio: aspect ratio ``r``; the allowed range is ``(r, 1 / r)``.
        :param value: erasing value.
        :param inplace: whether to erase in place. Defaults to False.
        """
        scale = (scale, scale)
        ratio = (ratio, 1. / ratio)
        passes = count if count != 0 else 1
        if passes <= 0:
            return img, boxes
        transform = RandomErasing(scale=scale, ratio=ratio, value=value, inplace=inplace)
        # Convert once: round-tripping through PIL on every pass re-quantises
        # the image to 8 bits each time for no benefit.
        tensor = self.to_tensor(img)
        for _ in range(passes):
            tensor = transform(tensor)
        return self.to_image(tensor), boxes

    # ------------------------------------------------------- pixel-level tweaks

    def random_bright(self, img, boxes, u=32):
        """Shift all pixel values by one random offset drawn from ``[-u, u] / 255``.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param u: maximum offset in 8-bit intensity units.
        """
        img = self.to_tensor(img)
        alpha = np.random.uniform(-u, u) / 255
        img += alpha
        img = img.clamp(min=0.0, max=1.0)
        return self.to_image(img), boxes

    def random_contrast(self, img, boxes, lower=0.5, upper=1.5):
        """Multiply all pixel values by a random factor in ``[lower, upper]`` and clip.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param lower: smallest scale factor.
        :param upper: largest scale factor.
        """
        img = self.to_tensor(img)
        alpha = np.random.uniform(lower, upper)
        img *= alpha
        img = img.clamp(min=0, max=1.0)
        return self.to_image(img), boxes

    def random_saturation(self, img, boxes, lower=0.5, upper=1.5):
        """Scale colour saturation by a random factor in ``[lower, upper]``.

        A factor of 0 gives a grayscale image, 1 leaves the image unchanged.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param lower: smallest saturation factor.
        :param upper: largest saturation factor.
        """
        alpha = np.random.uniform(lower, upper)
        # Scaling channel 1 of an RGB tensor only tints the green channel;
        # adjust_saturation blends with the grayscale image, i.e. real saturation.
        img = self.transforms.functional.adjust_saturation(img, alpha)
        return img, boxes

    def add_gasuss_noise(self, img, boxes, mean=0, std=0.1):
        """Add element-wise Gaussian noise ``N(mean, std)`` (in [0, 1] units) and clip.

        The misspelt method name is kept for backward compatibility.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        :param mean: noise mean.
        :param std: noise standard deviation.
        """
        img = self.to_tensor(img)
        noise = torch.normal(mean, std, img.shape)
        img += noise
        img = img.clamp(min=0, max=1.0)
        return self.to_image(img), boxes

    def add_salt_noise(self, img, boxes):
        """Set a random subset of tensor elements (per channel) to 1.0.

        NOTE(review): the affected fraction is ``1 - alpha`` with ``alpha``
        uniform in [0, 1), so occasionally almost the whole image is whitened.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        """
        img = self.to_tensor(img)
        noise = torch.rand(img.shape)
        alpha = np.random.random()
        img[noise > alpha] = 1.0
        return self.to_image(img), boxes

    def add_pepper_noise(self, img, boxes):
        """Set a random subset of tensor elements (per channel) to 0.

        NOTE(review): as with salt noise, the affected fraction is uniform in
        (0, 1], so some samples become almost entirely black.

        :param img: PIL image.
        :param boxes: box tensor, returned as-is.
        """
        img = self.to_tensor(img)
        noise = torch.rand(img.shape)
        alpha = np.random.random()
        img[noise > alpha] = 0
        return self.to_image(img), boxes

    # ------------------------------------------------------- multi-image mixing

    def mixup(self, img1, img2, box1, box2, alpha=32.):
        """With probability 0.5 blend two images and concatenate their boxes.

        Both images are first resized to the larger width and larger height of
        the pair. The blend weight is drawn from ``Beta(alpha, alpha)``; large
        ``alpha`` keeps it close to 0.5.

        :param img1: first PIL image.
        :param img2: second PIL image.
        :param box1: boxes of ``img1``, rows ``[label, x1, y1, x2, y2]``.
        :param box2: boxes of ``img2``, same layout.
        :param alpha: Beta distribution parameter.
        :return: ``(mixed_img, boxes_of_both)``, or ``(img1, box1)`` unchanged.
        """
        p = torch.rand(1)
        if p > 0.5:
            max_w = max(img1.size[0], img2.size[0])
            max_h = max(img1.size[1], img2.size[1])
            img1, box1 = self.resize_(img1, box1, (max_w, max_h))
            img2, box2 = self.resize_(img2, box2, (max_w, max_h))

            img1 = self.to_tensor(img1)
            img2 = self.to_tensor(img2)
            weight = np.random.beta(alpha, alpha)
            miximg = weight * img1 + (1 - weight) * img2
            return self.to_image(miximg), torch.cat([box1, box2])
        return img1, box1

    def mosaic(self, imgpath, cls, imgfile, img_size=640):
        """With probability 0.5 build a mosaic sample, otherwise load the image itself.

        :param imgpath: path of the primary image.
        :param cls: forwarded to ``load_mosaic`` / ``load_json_points``.
        :param imgfile: forwarded to ``load_mosaic``.
        :param img_size: forwarded to ``load_mosaic``.
        :return: ``load_mosaic``'s result, or ``(PIL image, load_json_points(...))``
            where the points come from the ``.json`` file next to the image.
        """
        import os

        p = torch.rand(1)
        if p > 0.5:
            return load_mosaic(imgpath, cls, imgfile, img_size=img_size)
        # Fallback branch (previously unreachable after the ``return`` above).
        img = Image.open(imgpath)
        # splitext only strips a real extension, unlike splitting on every '.'.
        jsonpath = os.path.splitext(imgpath)[0] + '.json'
        return img, load_json_points(jsonpath, cls)

    # ----------------------------------------------------------------- debugging

    def draw_img(self, img, boxes):
        """Draw every box outline onto ``img`` in place and return the same image.

        :param img: PIL image (modified in place).
        :param boxes: rows ``[label, x1, y1, x2, y2]``; the label column is skipped.
        :return: ``img``, for convenience.
        """
        draw = ImageDraw.Draw(img)
        for box in boxes:
            # Exactly four corner values, as plain floats rather than tensor scalars.
            draw.rectangle([float(v) for v in box[1:5]], outline='yellow', width=2)
        return img










