pytorch版TTA 源码阅读2

class DualTransform(BaseTransform):
    """Base class for transforms whose effect must also be undone on outputs.

    The class adds nothing on top of ``BaseTransform`` itself. Its subclasses
    (e.g. ``HorizontalFlip``, ``Rotate90``, ``Scale``) override
    ``apply_deaug_mask`` / ``apply_deaug_keypoints`` to invert their spatial
    effect. ``ImageOnlyTransform``, by contrast, passes masks and keypoints
    through unchanged.
    """


class ImageOnlyTransform(BaseTransform):
    """Base class whose de-augmentation hooks are all the identity.

    Masks, labels and keypoints are handed back exactly as received, and any
    transform parameters are ignored. Subclasses that do move pixels around
    (e.g. ``FiveCrops``) override the relevant hooks.
    """

    def apply_deaug_mask(self, mask, *_args, **_params):
        # Nothing to invert: return the mask as-is.
        return mask

    def apply_deaug_label(self, label, *_args, **_params):
        # Labels are never affected by the augmentation.
        return label

    def apply_deaug_keypoints(self, keypoints, *_args, **_params):
        # Keypoints are left where they are.
        return keypoints



from functools import partial
from typing import Optional, List, Union, Tuple
from . import functional as F
from .base import DualTransform, ImageOnlyTransform

class HorizontalFlip(DualTransform):
    """Flip images horizontally (left <-> right).

    The transform parameter ``apply`` takes the values ``False`` (identity)
    and ``True`` (flip). A flip is its own inverse, so de-augmenting a mask
    simply flips it again. Labels pass through unchanged, and keypoints are
    mirrored with ``F.keypoints_hflip``.
    """

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        # F.hflip is a thin wrapper around Tensor.flip on one spatial dim.
        return F.hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        # Flipping twice restores the original orientation.
        return F.hflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        return F.keypoints_hflip(keypoints) if apply else keypoints

class VerticalFlip(DualTransform):
    """Flip images vertically (up <-> down).

    The transform parameter ``apply`` takes the values ``False`` (identity)
    and ``True`` (flip). A flip is its own inverse, so de-augmenting a mask
    flips it again. Labels pass through unchanged, and keypoints are mirrored
    with ``F.keypoints_vflip``.
    """

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        # F.vflip is Tensor.flip too, just on the other spatial dim than hflip.
        return F.vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        # Flipping twice restores the original orientation.
        return F.vflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        return F.keypoints_vflip(keypoints) if apply else keypoints

class Rotate90(DualTransform):
    """Rotate images by multiples of 90 degrees (0/90/180/270).

    Args:
        angles (list): angles (in degrees) to rotate images by. ``0``, the
            identity, is prepended when missing. Angles are reduced modulo
            360 and floored to whole quarter turns.
    """

    # 0 degrees is the no-op parameter.
    identity_param = 0

    def __init__(self, angles: List[int]):
        if self.identity_param not in angles:
            angles = [self.identity_param] + list(angles)

        super().__init__("angle", angles)

    @staticmethod
    def _angle_to_k(angle):
        """Convert an angle in degrees to a number of quarter turns in [0, 3].

        ``angle % 360`` maps negative angles and angles >= 360 into
        [0, 360). The result is therefore always a valid ``k`` for
        ``F.keypoints_rot90``, which rejects anything outside 0..3.
        """
        return (angle % 360) // 90

    # F.rot90(image, k) wraps torch.rot90 over dims (2, 3). Per the original
    # notes it rotates counterclockwise, k being the number of quarter turns.
    def apply_aug_image(self, image, angle=0, **kwargs):
        return F.rot90(image, self._angle_to_k(angle))

    def apply_deaug_mask(self, mask, angle=0, **kwargs):
        # Undo the augmentation by rotating through the opposite angle.
        return self.apply_aug_image(mask, -angle)

    def apply_deaug_label(self, label, angle=0, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, angle=0, **kwargs):
        # Rotate keypoints back through the opposite angle.
        return F.keypoints_rot90(keypoints, k=self._angle_to_k(-angle))

class Scale(DualTransform):
    """Scale images by constant factors.

    Args:
        scales (List[Union[int, float]]): scale factors for the spatial image
            dimensions. ``1``, the identity, is prepended when missing.
        interpolation (str): ``mode`` passed to
            ``torch.nn.functional.interpolate``, e.g. "nearest" or
            "bilinear". The helper works on 4-D batches, for which PyTorch
            expects "bilinear" rather than "linear".
        align_corners (bool): passed to ``torch.nn.functional.interpolate``.
            Only meaningful for the linear-family modes; leave it as ``None``
            for "nearest".
    """

    # A factor of 1 means "no scaling".
    identity_param = 1

    def __init__(
        self,
        scales: List[Union[int, float]],
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        if self.identity_param not in scales:
            scales = [self.identity_param] + list(scales)
        self.interpolation = interpolation
        self.align_corners = align_corners

        super().__init__("scale", scales)

    # F.scale uses torch.nn.functional.interpolate under the hood.
    def apply_aug_image(self, image, scale=1, **kwargs):
        if scale != self.identity_param:
            image = F.scale(
                image,
                scale,
                interpolation=self.interpolation,
                align_corners=self.align_corners,
            )
        return image

    def apply_deaug_mask(self, mask, scale=1, **kwargs):
        # Invert the augmentation by scaling with the reciprocal factor.
        # NOTE(review): F.scale truncates sizes with int(), so e.g. 5 -> 2 -> 4
        # for scale=0.5; the mask may not return to its exact original size.
        if scale != self.identity_param:
            mask = F.scale(
                mask,
                1 / scale,
                interpolation=self.interpolation,
                align_corners=self.align_corners,
            )
        return mask

    def apply_deaug_label(self, label, scale=1, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, scale=1, **kwargs):
        # Keypoints are returned unchanged by this transform.
        return keypoints

class Resize(DualTransform):
    """Resize images to a set of fixed spatial sizes.

    Same pattern as ``Scale``, but parameterised by absolute target sizes
    instead of factors.

    Args:
        sizes (List[Tuple[int, int]]): target sizes for the spatial image
            dimensions.
        original_size (Tuple[int, int]): optional, the original image size.
            When given, it is prepended to ``sizes`` as the identity option.
            It is required to de-augment masks.
        interpolation (str): ``mode`` passed to
            ``torch.nn.functional.interpolate`` (e.g. "nearest" or
            "bilinear").
        align_corners (bool): passed to ``torch.nn.functional.interpolate``.
            Only meaningful for the linear-family modes.
    """

    def __init__(
        self,
        sizes: List[Tuple[int, int]],
        original_size: Optional[Tuple[int, int]] = None,
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        if original_size is not None and original_size not in sizes:
            sizes = [original_size] + list(sizes)
        self.interpolation = interpolation
        self.align_corners = align_corners
        self.original_size = original_size

        super().__init__("size", sizes)

    def apply_aug_image(self, image, size, **kwargs):
        # Resizing to the original size is the identity, so skip it.
        if size != self.original_size:
            image = F.resize(
                image,
                size,
                interpolation=self.interpolation,
                align_corners=self.align_corners,
            )
        return image

    def apply_deaug_mask(self, mask, size, **kwargs):
        # Without a known original size there is nothing to resize back to.
        if self.original_size is None:
            raise ValueError(
                "Provide original image size to make mask backward transformation"
            )
        if size != self.original_size:
            mask = F.resize(
                mask,
                self.original_size,
                interpolation=self.interpolation,
                align_corners=self.align_corners,
            )
        return mask

    def apply_deaug_label(self, label, size=1, **kwargs):
        return label

    def apply_deaug_keypoints(self, keypoints, size=1, **kwargs):
        # Keypoints are returned unchanged by this transform.
        return keypoints

class Add(ImageOnlyTransform):
    """Add a constant value to every pixel.

    Args:
        values (List[float]): values to add to each pixel. ``0``, the
            identity, is prepended when missing.
    """

    # Adding 0 is the no-op parameter.
    identity_param = 0

    def __init__(self, values: List[float]):

        if self.identity_param not in values:
            values = [self.identity_param] + list(values)
        super().__init__("value", values)

    def apply_aug_image(self, image, value=0, **kwargs):
        # Skip the no-op addition for the identity parameter.
        if value != self.identity_param:
            image = F.add(image, value)
        return image

class Multiply(ImageOnlyTransform):
    """Multiply every pixel by a constant factor.

    Args:
        factors (List[float]): factors to multiply each pixel by. ``1``, the
            identity, is prepended when missing.
    """

    # Multiplying by 1 is the no-op parameter.
    identity_param = 1

    def __init__(self, factors: List[float]):
        if self.identity_param not in factors:
            factors = [self.identity_param] + list(factors)
        super().__init__("factor", factors)

    def apply_aug_image(self, image, factor=1, **kwargs):
        # Skip the no-op multiplication for the identity parameter.
        if factor != self.identity_param:
            image = F.multiply(image, factor)
        return image

class FiveCrops(ImageOnlyTransform):
    """Make 4 crops, one per corner, plus a center crop.

    Args:
        crop_height (int): crop height in pixels
        crop_width (int): crop width in pixels

    The crops are plain tensor slices (see the F.crop_* helpers), so masks
    and keypoints cannot be mapped back and their de-augmentation hooks
    raise. Labels use the inherited identity hook.
    """

    def __init__(self, crop_height, crop_width):
        # One partial per crop position. These are the candidate values of
        # the "crop_fn" transform parameter.
        crop_functions = (
            partial(F.crop_lt, crop_h=crop_height, crop_w=crop_width),
            partial(F.crop_lb, crop_h=crop_height, crop_w=crop_width),
            partial(F.crop_rb, crop_h=crop_height, crop_w=crop_width),
            partial(F.crop_rt, crop_h=crop_height, crop_w=crop_width),
            partial(F.center_crop, crop_h=crop_height, crop_w=crop_width),
        )
        super().__init__("crop_fn", crop_functions)

    def apply_aug_image(self, image, crop_fn=None, **kwargs):
        # crop_fn is one of the partials built in __init__.
        return crop_fn(image)

    def apply_deaug_mask(self, mask, **kwargs):
        raise ValueError("`FiveCrop` augmentation is not suitable for mask!")

    def apply_deaug_keypoints(self, keypoints, **kwargs):
        raise ValueError("`FiveCrop` augmentation is not suitable for keypoints!")



import torch
import torch.nn.functional as F

def rot90(x, k=1):
    """Rotate a batch of images by 90 degrees, ``k`` times.

    The rotation is applied over dims 2 and 3 via ``torch.rot90``.
    """
    spatial_dims = [2, 3]
    return torch.rot90(x, k=k, dims=spatial_dims)

def hflip(x):
    """Flip a batch of images horizontally by reversing dim 3."""
    return torch.flip(x, dims=(3,))

def vflip(x):
    """Flip a batch of images vertically by reversing dim 2."""
    return torch.flip(x, dims=(2,))

def sum(x1, x2):
    """Return ``x1 + x2`` (element-wise for tensors).

    Note: this shadows the builtin ``sum`` within this module.
    """
    total = x1 + x2
    return total

def add(x, value):
    """Return ``x`` with ``value`` added (broadcast over tensors)."""
    shifted = x + value
    return shifted

def max(x1, x2):
    """Return the element-wise maximum of two tensors via ``torch.max``.

    Note: this shadows the builtin ``max`` within this module.
    """
    larger = torch.max(x1, x2)
    return larger

def min(x1, x2):
    """Return the element-wise minimum of two tensors via ``torch.min``.

    Note: this shadows the builtin ``min`` within this module.
    """
    smaller = torch.min(x1, x2)
    return smaller

def multiply(x, factor):
    """Return ``x`` scaled by ``factor`` (broadcast over tensors)."""
    product = x * factor
    return product

def scale(x, scale_factor, interpolation="nearest", align_corners=None):
    """Scale a batch of images by ``scale_factor``.

    Args:
        x: 4-D tensor; its last two dims are treated as (h, w).
        scale_factor: multiplier applied to both spatial dims. The resulting
            sizes are truncated with ``int``.
        interpolation: ``mode`` for ``torch.nn.functional.interpolate``.
        align_corners: forwarded to ``torch.nn.functional.interpolate``.

    Returns:
        The resized tensor, with spatial size
        ``(int(h * scale_factor), int(w * scale_factor))``.
    """
    h, w = x.shape[2:]
    new_h = int(h * scale_factor)
    new_w = int(w * scale_factor)
    return F.interpolate(
        x, size=(new_h, new_w), mode=interpolation, align_corners=align_corners
    )

def resize(x, size, interpolation="nearest", align_corners=None):
    """Resize a batch of images to the spatial ``size``.

    Thin wrapper over ``torch.nn.functional.interpolate``, with
    ``interpolation`` forwarded as its ``mode``.
    """
    options = dict(mode=interpolation, align_corners=align_corners)
    return F.interpolate(x, size=size, **options)

def crop(x, x_min=None, x_max=None, y_min=None, y_max=None):
    """Crop a batch of images with plain slicing on dims 2 (y) and 3 (x).

    ``None`` bounds behave like omitted slice bounds.
    """
    rows = slice(y_min, y_max)
    cols = slice(x_min, x_max)
    return x[:, :, rows, cols]

def crop_lt(x, crop_h, crop_w):
    """Crop the left-top corner: the first ``crop_h`` rows and ``crop_w`` columns."""
    return x[:, :, :crop_h, :crop_w]

def crop_lb(x, crop_h, crop_w):
    """Crop the left-bottom corner: the last ``crop_h`` rows, first ``crop_w`` columns.

    The start row is computed explicitly instead of slicing with
    ``-crop_h``. Since ``-0 == 0``, that slice would return the whole height
    for ``crop_h == 0`` instead of an empty crop. Clamping at 0 keeps an
    oversized crop returning the full height, as before.
    """
    y_start = max(x.shape[2] - crop_h, 0)
    return x[:, :, y_start:, 0:crop_w]

def crop_rt(x, crop_h, crop_w):
    """Crop the right-top corner: the first ``crop_h`` rows, last ``crop_w`` columns.

    The start column is computed explicitly instead of slicing with
    ``-crop_w``. Since ``-0 == 0``, that slice would return the whole width
    for ``crop_w == 0`` instead of an empty crop. Clamping at 0 keeps an
    oversized crop returning the full width, as before.
    """
    x_start = max(x.shape[3] - crop_w, 0)
    return x[:, :, 0:crop_h, x_start:]

def crop_rb(x, crop_h, crop_w):
    """Crop the right-bottom corner: the last ``crop_h`` rows and ``crop_w`` columns.

    The start indices are computed explicitly instead of slicing with
    negative sizes. Since ``-0 == 0``, those slices would return the whole
    extent for a zero-sized crop. Clamping at 0 keeps oversized crops
    returning the full extent, as before.
    """
    y_start = max(x.shape[2] - crop_h, 0)
    x_start = max(x.shape[3] - crop_w, 0)
    return x[:, :, y_start:, x_start:]

def center_crop(x, crop_h, crop_w):
    """Make a ``crop_h`` x ``crop_w`` crop centred on the image (dims 2 and 3).

    If the crop is larger than the image along a dimension, that whole
    dimension is returned. This matches the corner-crop helpers.
    """
    # Find the image centre first, then the crop window around it.
    center_h = x.shape[2] // 2
    center_w = x.shape[3] // 2
    half_crop_h = crop_h // 2
    half_crop_w = crop_w // 2

    # Clamp the start indices at 0. Python slicing would read a negative
    # start as "count from the end" and silently return a thin sliver of
    # the image when the crop is bigger than the input.
    y_min = max(center_h - half_crop_h, 0)
    # The `% 2` term gives odd crop sizes their extra row/column.
    y_max = center_h + half_crop_h + crop_h % 2
    x_min = max(center_w - half_crop_w, 0)
    x_max = center_w + half_crop_w + crop_w % 2

    return x[:, :, y_min:y_max, x_min:x_max]

def _disassemble_keypoints(keypoints):
    x = keypoints[..., 0]
    y = keypoints[..., 1]
    return x, y

def _assemble_keypoints(x, y):
    return torch.stack([x, y], dim=-1)

def keypoints_hflip(keypoints):
    """Mirror keypoints horizontally by replacing x with ``1 - x``.

    Assumes coordinates are normalised to [0, 1] (implied by ``1 - x``);
    confirm against the caller.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    mirrored_xs = 1. - xs
    return _assemble_keypoints(mirrored_xs, ys)

def keypoints_vflip(keypoints):
    """Mirror keypoints vertically by replacing y with ``1 - y``.

    Assumes coordinates are normalised to [0, 1] (implied by ``1 - y``);
    confirm against the caller.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    mirrored_ys = 1. - ys
    return _assemble_keypoints(xs, mirrored_ys)

def keypoints_rot90(keypoints, k=1):
    """Rotate keypoints by ``k`` quarter turns.

    Args:
        keypoints: tensor whose last dim holds (x, y). Presumably normalised
            to [0, 1], as the ``1 - x`` / ``1 - y`` terms suggest.
        k: number of quarter turns; must be 0, 1, 2 or 3.

    Returns:
        The input object itself for ``k == 0``, otherwise new keypoints.

    Raises:
        ValueError: if ``k`` is not in {0, 1, 2, 3}.
    """
    if k not in {0, 1, 2, 3}:
        raise ValueError("Parameter k must be in [0:3]")
    if k == 0:
        return keypoints

    x, y = _disassemble_keypoints(keypoints)
    # New (x, y) for each quarter-turn count. Lambdas keep evaluation lazy,
    # so only the selected mapping is computed.
    rotations = {
        1: lambda: (y, 1. - x),
        2: lambda: (1. - x, 1. - y),
        3: lambda: (1. - y, x),
    }
    return _assemble_keypoints(*rotations[k]())


  • 可以借此复习基本的数据增强类型，以及对应的 PyTorch 操作
  • Python 中类的继承与方法重写
  • PyTorch 插值中的角点对齐参数（align_corners）值得深入研究
