pytorch版TTA 源码阅读2

TTA 源码阅读2

1.1 transforms.py

主要是图片增强方法的文件

首先要提一下的是:这里几何类的transform(翻转、旋转、缩放、Resize)都继承自DualTransform,而DualTransform的类体只有一个pass,等于原样继承了上一篇解析过的BaseTransform:

class DualTransform(BaseTransform):
    """Base class for transforms that may change image geometry.

    Adds no members of its own: it behaves exactly like ``BaseTransform``.
    Subclasses implement the ``apply_aug_*`` / ``apply_deaug_*`` hooks.
    """

因此读代码时可以把DualTransform直接当作BaseTransform来看。

其次是ImageOnlyTransform,它同样继承BaseTransform,只不过预先实现了mask、label、keypoints三个逆变换方法,而且都是原样返回输入:这类transform(如Add、Multiply)只改变像素值、不改变几何位置,所以逆操作可以直接写出来。

class ImageOnlyTransform(BaseTransform):
    """Base class for transforms that only alter the image itself.

    The de-augmentation hooks below are pass-throughs: each returns its first
    argument untouched. Subclasses that do change geometry must override them.
    """

    def apply_deaug_mask(self, mask, *args, **kwargs):
        """Return ``mask`` unchanged."""
        return mask

    def apply_deaug_label(self, label, *args, **kwargs):
        """Return ``label`` unchanged."""
        return label

    def apply_deaug_keypoints(self, keypoints, *args, **kwargs):
        """Return ``keypoints`` unchanged."""
        return keypoints

上一篇提到的BaseTransform里的这些父类方法都没有具体实现,在transforms.py中就可以看到各个子类分别对它们进行"重写"(override)。

identity_param参数实际上就是每个操作里对应的恒等映射参数值,比如旋转时,该参数为0,代表旋转0度,那么实际上就是原图;缩放时该参数为1,缩放比例为1实际也就是原图。各子类的__init__会在用户给的参数列表中没有该值时,把identity_param插到列表最前面,从而保证增强的候选中始终包含原图。

from functools import partial
from typing import Optional, List, Union, Tuple
from . import functional as F
from .base import DualTransform, ImageOnlyTransform


class HorizontalFlip(DualTransform):
    """Mirror images horizontally (swap left and right)."""

    # ``apply=False`` means "leave the image as is".
    identity_param = False

    def __init__(self):
        # Two variants per input: unflipped and flipped.
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Flip ``image`` with ``F.hflip`` when ``apply`` is truthy."""
        return F.hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the flip on ``mask``; a flip is its own inverse."""
        return F.hflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        """Return ``label`` unchanged; flipping does not affect labels."""
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        """Undo the flip on ``keypoints`` with ``F.keypoints_hflip``."""
        return F.keypoints_hflip(keypoints) if apply else keypoints


class VerticalFlip(DualTransform):
    """Mirror images vertically (swap top and bottom)."""

    # ``apply=False`` means "leave the image as is".
    identity_param = False

    def __init__(self):
        # Two variants per input: unflipped and flipped.
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Flip ``image`` with ``F.vflip`` when ``apply`` is truthy.

        Same as the horizontal flip, just along the other spatial dim.
        """
        return F.vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the flip on ``mask``; a flip is its own inverse."""
        return F.vflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        """Return ``label`` unchanged; flipping does not affect labels."""
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        """Undo the flip on ``keypoints`` with ``F.keypoints_vflip``."""
        return F.keypoints_vflip(keypoints) if apply else keypoints


class Rotate90(DualTransform):
    """Rotate images by multiples of 90 degrees (0/90/180/270).

    Args:
        angles (list): angles to rotate images, in degrees. Each must be a
            multiple of 90; negative values and values of 360 or more are
            reduced modulo 360.

    Raises:
        ValueError: if any angle is not a multiple of 90.
    """

    # Rotating by 0 degrees leaves the image unchanged.
    identity_param = 0

    def __init__(self, angles: List[int]):
        angles = list(angles)
        invalid = [a for a in angles if a % 90 != 0]
        if invalid:
            # Such angles used to be silently floored to a lower multiple of
            # 90 (e.g. 45 -> 0), duplicating another augmentation.
            raise ValueError(
                "Rotate90 angles must be multiples of 90, got {}".format(invalid)
            )
        if self.identity_param not in angles:
            angles = [self.identity_param] + angles
        super().__init__("angle", angles)

    @staticmethod
    def _quarter_turns(angle):
        """Map an angle in degrees to a number of 90-degree turns in [0, 3]."""
        # Python's // and % both floor, so negatives and angles >= 360 wrap
        # correctly: -90 -> 3, 450 -> 1, -360 -> 0.
        return (angle // 90) % 4

    # F.rot90(image, k) is torch.rot90 over dims (2, 3); k counts 90-degree
    # turns in torch.rot90's direction (from dim 2 towards dim 3).
    def apply_aug_image(self, image, angle=0, **kwargs):
        """Rotate ``image`` by ``angle`` degrees."""
        return F.rot90(image, self._quarter_turns(angle))

    def apply_deaug_mask(self, mask, angle=0, **kwargs):
        """Undo the rotation on ``mask`` by rotating by ``-angle``."""
        return self.apply_aug_image(mask, -angle)

    def apply_deaug_label(self, label, angle=0, **kwargs):
        """Return ``label`` unchanged; rotation does not affect labels."""
        return label

    def apply_deaug_keypoints(self, keypoints, angle=0, **kwargs):
        """Undo the rotation on ``keypoints`` by rotating by ``-angle``."""
        # Normalizing here keeps k within 0..3, the only range
        # F.keypoints_rot90 accepts; the old arithmetic could produce 4 or -1.
        return F.keypoints_rot90(keypoints, k=self._quarter_turns(-angle))


class Scale(DualTransform):
    """Rescale images by a multiplicative factor.

    Args:
        scales (List[Union[int, float]]): candidate factors applied to both
            spatial dimensions; the identity factor 1 is always included
        interpolation (str): mode passed through ``F.scale`` to
            ``torch.nn.functional.interpolate`` (e.g. "nearest" or
            "bilinear" for 4D input)
        align_corners (bool): passed through to
            ``torch.nn.functional.interpolate``
    """

    # A factor of 1 keeps the original resolution.
    identity_param = 1

    def __init__(
        self,
        scales: List[Union[int, float]],
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        # align_corners decides whether the corner pixels of the input and
        # output grids are aligned; see
        # https://zhuanlan.zhihu.com/p/87572724 for an illustrated explanation.
        self.align_corners = align_corners
        self.interpolation = interpolation
        if self.identity_param not in scales:
            scales = [self.identity_param, *scales]
        super().__init__("scale", scales)

    # F.scale wraps torch.nn.functional.interpolate.
    def apply_aug_image(self, image, scale=1, **kwargs):
        """Rescale ``image`` by ``scale``; the identity factor is a no-op."""
        if scale == self.identity_param:
            return image
        return F.scale(
            image,
            scale,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_deaug_mask(self, mask, scale=1, **kwargs):
        """Rescale ``mask`` by ``1 / scale`` to undo the augmentation.

        NOTE(review): F.scale floors the new size, so rescaling by ``scale``
        and then ``1 / scale`` may not restore the exact original size
        (e.g. height 5 at 0.5 -> 2, then back -> 4).
        """
        if scale == self.identity_param:
            return mask
        return F.scale(
            mask,
            1 / scale,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_deaug_label(self, label, scale=1, **kwargs):
        """Return ``label`` unchanged; scaling does not affect labels."""
        return label

    def apply_deaug_keypoints(self, keypoints, scale=1, **kwargs):
        """Return ``keypoints`` unchanged (presumably normalized coordinates)."""
        return keypoints


class Resize(DualTransform):
    """Resize images to fixed spatial sizes.

    Args:
        sizes (List[Tuple[int, int]]): candidate output sizes, forwarded as
            ``size`` to ``torch.nn.functional.interpolate`` via ``F.resize``
        original_size (Tuple[int, int]): optional size of the input images;
            required by ``apply_deaug_mask`` and, when given, added to
            ``sizes`` so the un-resized image is one of the variants
        interpolation (str): mode for ``torch.nn.functional.interpolate``
            (e.g. "nearest" or "bilinear" for 4D input)
        align_corners (bool): passed through to
            ``torch.nn.functional.interpolate``
    """

    # Works like Scale, but each variant is an absolute size rather than a
    # factor; ``original_size`` plays the role of the identity parameter.
    def __init__(
        self,
        sizes: List[Tuple[int, int]],
        original_size: Tuple[int, int] = None,
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        self.original_size = original_size
        self.interpolation = interpolation
        self.align_corners = align_corners
        needs_original = original_size is not None and original_size not in sizes
        if needs_original:
            sizes = [original_size, *sizes]
        super().__init__("size", sizes)

    def apply_aug_image(self, image, size, **kwargs):
        """Resize ``image`` to ``size``; skipped when ``size`` equals ``original_size``."""
        if size == self.original_size:
            return image
        return F.resize(
            image,
            size,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_deaug_mask(self, mask, size, **kwargs):
        """Resize ``mask`` back to ``original_size``.

        Raises:
            ValueError: if ``original_size`` was not provided.
        """
        if self.original_size is None:
            raise ValueError(
                "Provide original image size to make mask backward transformation"
            )
        if size == self.original_size:
            return mask
        return F.resize(
            mask,
            self.original_size,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_deaug_label(self, label, size=1, **kwargs):
        """Return ``label`` unchanged; resizing does not affect labels."""
        return label

    def apply_deaug_keypoints(self, keypoints, size=1, **kwargs):
        """Return ``keypoints`` unchanged (presumably normalized coordinates)."""
        return keypoints


class Add(ImageOnlyTransform):
    """Add a constant offset to every pixel.

    Args:
        values (List[float]): candidate offsets; 0 (no change) is always
            included
    """

    # Adding 0 leaves the image unchanged.
    identity_param = 0

    def __init__(self, values: List[float]):
        if self.identity_param not in values:
            values = [self.identity_param, *values]
        super().__init__("value", values)

    def apply_aug_image(self, image, value=0, **kwargs):
        """Return ``F.add(image, value)``, skipping the call for the identity value."""
        return image if value == self.identity_param else F.add(image, value)


class Multiply(ImageOnlyTransform):
    """Multiply every pixel by a constant factor.

    Args:
        factors (List[float]): candidate factors; 1 (no change) is always
            included
    """

    # Multiplying by 1 leaves the image unchanged.
    identity_param = 1

    def __init__(self, factors: List[float]):
        if self.identity_param not in factors:
            factors = [self.identity_param, *factors]
        super().__init__("factor", factors)

    def apply_aug_image(self, image, factor=1, **kwargs):
        """Return ``F.multiply(image, factor)``, skipping the call for the identity factor."""
        return image if factor == self.identity_param else F.multiply(image, factor)


class FiveCrops(ImageOnlyTransform):
    """Produce four corner crops plus a center crop of each image.

    Args:
        crop_height (int): crop height in pixels
        crop_width (int): crop width in pixels
    """

    def __init__(self, crop_height, crop_width):
        size = dict(crop_h=crop_height, crop_w=crop_width)
        # Order: left-top, left-bottom, right-bottom, right-top, center.
        # The F.crop_* helpers are plain tensor slicing.
        crop_functions = tuple(
            partial(crop, **size)
            for crop in (F.crop_lt, F.crop_lb, F.crop_rb, F.crop_rt, F.center_crop)
        )
        super().__init__("crop_fn", crop_functions)

    def apply_aug_image(self, image, crop_fn=None, **kwargs):
        """Apply the selected crop function (one built in ``__init__``) to ``image``."""
        return crop_fn(image)

    def apply_deaug_mask(self, mask, **kwargs):
        """Masks cannot be restored after cropping, so always raise."""
        raise ValueError("`FiveCrop` augmentation is not suitable for mask!")

    def apply_deaug_keypoints(self, keypoints, **kwargs):
        """Keypoints cannot be restored after cropping, so always raise."""
        raise ValueError("`FiveCrop` augmentation is not suitable for keypoints!")

1.2 functional.py

这里是transforms.py中各个变换的具体实现,比较简单,基本都是直接调用torch的方法,只给中心crop加了注释。

import torch
import torch.nn.functional as F


def rot90(x, k=1):
    """Rotate a batch of images by ``k`` quarter turns.

    Acts on dims 2 and 3 (height/width for NCHW tensors); the direction is
    that of ``torch.rot90``, i.e. from dim 2 towards dim 3.
    """
    spatial_dims = (2, 3)
    return torch.rot90(x, k, spatial_dims)


def hflip(x):
    """Mirror a batch of images left-right by reversing dim 3."""
    return torch.flip(x, dims=(3,))


def vflip(x):
    """Mirror a batch of images top-bottom by reversing dim 2."""
    return torch.flip(x, dims=(2,))


def sum(x1, x2):
    """Return the element-wise sum ``x1 + x2`` (broadcasting like ``+``)."""
    total = x1 + x2
    return total


def add(x, value):
    """Return ``x`` shifted by ``value`` (``x + value``)."""
    shifted = x + value
    return shifted


def max(x1, x2):
    """Return the element-wise maximum of two tensors via ``torch.max``."""
    larger = torch.max(x1, x2)
    return larger


def min(x1, x2):
    """Return the element-wise minimum of two tensors via ``torch.min``."""
    smaller = torch.min(x1, x2)
    return smaller


def multiply(x, factor):
    """Return ``x`` scaled by ``factor`` (``x * factor``)."""
    product = x * factor
    return product


def scale(x, scale_factor, interpolation="nearest", align_corners=None):
    """Scale a batch of images by ``scale_factor`` with the given interpolation.

    The new height/width are ``floor(h * scale_factor)`` / ``floor(w * scale_factor)``
    (dims 2 and 3 of ``x``), computed robustly against float rounding noise,
    and the resize is done with ``torch.nn.functional.interpolate``.
    """
    h, w = x.shape[2:]
    # Plain int() truncation turns products such as 100 * 0.29
    # (= 28.999999999999996 in binary floating point) into 28 instead of 29.
    # A tiny epsilon absorbs that noise while keeping floor semantics for
    # genuinely fractional sizes (e.g. 5 * 0.5 -> 2).
    eps = 1e-6
    new_h = int(h * scale_factor + eps)
    new_w = int(w * scale_factor + eps)
    return F.interpolate(
        x, size=(new_h, new_w), mode=interpolation, align_corners=align_corners
    )


def resize(x, size, interpolation="nearest", align_corners=None):
    """Resize a batch of images to spatial ``size`` via ``F.interpolate``."""
    options = {"mode": interpolation, "align_corners": align_corners}
    return F.interpolate(x, size=size, **options)


def crop(x, x_min=None, x_max=None, y_min=None, y_max=None):
    """Crop a batch of images to rows [y_min, y_max) and columns [x_min, x_max).

    ``None`` bounds leave that side of the axis uncropped.
    """
    rows = slice(y_min, y_max)
    cols = slice(x_min, x_max)
    return x[:, :, rows, cols]


def crop_lt(x, crop_h, crop_w):
    """Take the top-left ``crop_h`` x ``crop_w`` corner of each image."""
    top_rows = slice(0, crop_h)
    left_cols = slice(0, crop_w)
    return x[:, :, top_rows, left_cols]


def crop_lb(x, crop_h, crop_w):
    """Take the bottom-left ``crop_h`` x ``crop_w`` corner of each image."""
    bottom_rows = slice(-crop_h, None)
    left_cols = slice(0, crop_w)
    return x[:, :, bottom_rows, left_cols]


def crop_rt(x, crop_h, crop_w):
    """Take the top-right ``crop_h`` x ``crop_w`` corner of each image."""
    top_rows = slice(0, crop_h)
    right_cols = slice(-crop_w, None)
    return x[:, :, top_rows, right_cols]


def crop_rb(x, crop_h, crop_w):
    """Take the bottom-right ``crop_h`` x ``crop_w`` corner of each image."""
    bottom_rows = slice(-crop_h, None)
    right_cols = slice(-crop_w, None)
    return x[:, :, bottom_rows, right_cols]


def center_crop(x, crop_h, crop_w):
    """Take a ``crop_h`` x ``crop_w`` window centred on each image (dims 2, 3).

    If the crop is larger than the image along an axis, the whole axis is
    kept, matching how the corner crops behave through slice clamping.
    """
    # Locate the image centre, then derive the crop bounds around it; the
    # ``% 2`` term gives odd crop sizes their extra pixel on the far side.
    center_h = x.shape[2] // 2
    center_w = x.shape[3] // 2
    half_crop_h = crop_h // 2
    half_crop_w = crop_w // 2

    y_min = center_h - half_crop_h
    y_max = center_h + half_crop_h + crop_h % 2
    x_min = center_w - half_crop_w
    x_max = center_w + half_crop_w + crop_w % 2

    # A negative start would make Python slicing count from the end of the
    # axis (e.g. a 6x6 crop of a 4x4 image used to return a 1x1 patch), so
    # clamp at 0. The builtin max() is shadowed in this module, hence the ifs.
    if y_min < 0:
        y_min = 0
    if x_min < 0:
        x_min = 0

    return x[:, :, y_min:y_max, x_min:x_max]


def _disassemble_keypoints(keypoints):
    """Split keypoints into the (x, y) components stored in the last dim.

    Only entries 0 and 1 of the last dim are returned; anything after is ignored.
    """
    return keypoints[..., 0], keypoints[..., 1]


def _assemble_keypoints(x, y):
    """Stack x and y components into keypoints along a new last dim."""
    components = (x, y)
    return torch.stack(components, dim=-1)


def keypoints_hflip(keypoints):
    """Mirror keypoints horizontally by mapping x to ``1 - x``.

    Assumes coordinates are normalized to [0, 1] — TODO confirm with callers.
    Only the (x, y) components survive the reassembly.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    mirrored_xs = 1.0 - xs
    return _assemble_keypoints(mirrored_xs, ys)


def keypoints_vflip(keypoints):
    """Mirror keypoints vertically by mapping y to ``1 - y``.

    Assumes coordinates are normalized to [0, 1] — TODO confirm with callers.
    Only the (x, y) components survive the reassembly.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    mirrored_ys = 1.0 - ys
    return _assemble_keypoints(xs, mirrored_ys)


def keypoints_rot90(keypoints, k=1):
    """Rotate (x, y) keypoints by ``k`` quarter turns.

    The mapping for k=1, (x, y) -> (y, 1 - x), matches ``rot90`` on images
    (``torch.rot90`` over dims (2, 3)) for coordinates normalized to [0, 1].
    ``k`` may be any integer, including negatives and values >= 4; it is
    reduced modulo 4 just as ``torch.rot90`` does for images. Previously only
    0..3 were accepted.

    Raises:
        ValueError: if ``k`` is not an integer value.
    """
    try:
        turns = k % 4
    except TypeError:
        raise ValueError("Parameter k must be an integer") from None
    if turns not in {0, 1, 2, 3}:
        # A non-integer k such as 1.5 leaves a fractional remainder.
        raise ValueError("Parameter k must be an integer")
    if turns == 0:
        return keypoints
    x, y = _disassemble_keypoints(keypoints)

    if turns == 1:
        xy = [y, 1. - x]
    elif turns == 2:
        xy = [1. - x, 1. - y]
    else:  # turns == 3
        xy = [1. - y, x]

    return _assemble_keypoints(*xy)

小结

  • 可以复习基本的数据增强类型,以及对应pytorch操作
  • python对于类的继承,方法的重写
  • pytorch插值中的角点对齐参数(align_corners)值得深入研究一下

你可能感兴趣的:(小记)