该文件主要实现各种图片增强(augmentation)方法。
首先要提一下:这里的transform类大多继承自DualTransform,而DualTransform完全继承自上一篇解析过的BaseTransform,没有添加任何内容:
class DualTransform(BaseTransform):
    """Transform that augments images and de-augments its outputs.

    Adds nothing on top of ``BaseTransform``; it exists only as a named base
    class for geometric transforms.
    """
因此看到DualTransform时,直接把它当作BaseTransform即可。
其次是ImageOnlyTransform,它与DualTransform类似,只不过预先实现了三个逆变换(deaug)方法。对于只改变像素值、不改变几何位置的变换(如Add、Multiply),mask、label和keypoints的逆操作就是原样返回,所以可以直接写好:
class ImageOnlyTransform(BaseTransform):
    """Base transform whose default de-augmentation is the identity.

    Masks, labels and keypoints are returned exactly as received. Subclasses
    that cannot be inverted this way override the relevant hook.
    """

    def apply_deaug_mask(self, mask, *args, **params):
        # Pass-through: the mask is returned as-is.
        return mask

    def apply_deaug_label(self, label, *args, **params):
        # Pass-through: the label is returned as-is.
        return label

    def apply_deaug_keypoints(self, keypoints, *args, **params):
        # Pass-through: keypoints are returned as-is.
        return keypoints
transforms中的各个类都对应BaseTransform。上一篇提到,BaseTransform中的这些父类方法并没有具体实现;在这里就能看到不同子类是如何"重写"(实现)这些父类方法的。
identity_param是每个操作对应的恒等变换参数值。例如旋转时该值为0,即旋转0度,结果就是原图;缩放时该值为1,即缩放比例为1,结果同样是原图。构造函数中如果传入的参数列表里没有这个值,会把它插到列表最前面,保证原图始终包含在增强集合中。
from functools import partial
from typing import Optional, List, Union, Tuple
from . import functional as F
from .base import DualTransform, ImageOnlyTransform
class HorizontalFlip(DualTransform):
    """Flip images horizontally (left->right)

    The transform parameter ``apply`` takes the values ``False`` (identity)
    and ``True`` (flip). A flip is its own inverse, so masks and keypoints
    are de-augmented by flipping them again.
    """

    # Not flipping leaves the image unchanged.
    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        # F.hflip mirrors the image left-to-right.
        return F.hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        # Flipping a second time restores the original orientation.
        return F.hflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        # Labels carry no spatial layout, so they pass through untouched.
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        return F.keypoints_hflip(keypoints) if apply else keypoints
class VerticalFlip(DualTransform):
    """Flip images vertically (up->down)

    The transform parameter ``apply`` takes the values ``False`` (identity)
    and ``True`` (flip). A flip is its own inverse, so masks and keypoints
    are de-augmented by flipping them again.
    """

    # Not flipping leaves the image unchanged.
    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        # F.vflip mirrors the image top-to-bottom.
        return F.vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        # Flipping a second time restores the original orientation.
        return F.vflip(mask) if apply else mask

    def apply_deaug_label(self, label, apply=False, **kwargs):
        # Labels carry no spatial layout, so they pass through untouched.
        return label

    def apply_deaug_keypoints(self, keypoints, apply=False, **kwargs):
        return F.keypoints_vflip(keypoints) if apply else keypoints
class Rotate90(DualTransform):
    """Rotate images by multiples of 90 degrees.

    Args:
        angles (list): rotation angles in degrees, expected to be multiples
            of 90. The identity angle 0 is prepended when missing, so the
            un-rotated image is always part of the augmentation set. Any
            angle (negative, >= 360, or a float such as 90.0) is normalized
            to a quarter-turn count in [0, 3].
    """

    # Rotating by 0 degrees leaves the image unchanged.
    identity_param = 0

    def __init__(self, angles: List[int]):
        if self.identity_param not in angles:
            angles = [self.identity_param] + list(angles)
        super().__init__("angle", angles)

    @staticmethod
    def _angle_to_k(angle) -> int:
        """Convert an angle in degrees to a number of quarter turns in [0, 3].

        Python's ``%`` returns a non-negative result for a positive divisor,
        so every angle is first folded into [0, 360). Angles that are not
        multiples of 90 are floored to the previous quarter turn. ``int()``
        makes float angles usable as a rotation count.
        """
        return int((angle % 360) // 90)

    def apply_aug_image(self, image, angle=0, **kwargs):
        # F.rot90 rotates the image k times by 90 degrees.
        return F.rot90(image, self._angle_to_k(angle))

    def apply_deaug_mask(self, mask, angle=0, **kwargs):
        # Undo the forward rotation by rotating by the opposite angle.
        return self.apply_aug_image(mask, -angle)

    def apply_deaug_label(self, label, angle=0, **kwargs):
        # Labels carry no spatial layout, so they pass through untouched.
        return label

    def apply_deaug_keypoints(self, keypoints, angle=0, **kwargs):
        # Rotate keypoints by the opposite angle. Normalizing keeps k inside
        # the {0, 1, 2, 3} range that F.keypoints_rot90 accepts.
        return F.keypoints_rot90(keypoints, k=self._angle_to_k(-angle))
class Scale(DualTransform):
    """Scale images by constant factors.

    Args:
        scales (List[Union[int, float]]): scale factors applied to both
            spatial dimensions; the identity factor 1 is prepended when
            missing.
        interpolation (str): interpolation mode, e.g. "nearest" or
            "bilinear" (see ``torch.nn.functional.interpolate``).
        align_corners (bool): see ``torch.nn.functional.interpolate``.
    """

    # A factor of 1 leaves the image unchanged.
    identity_param = 1

    def __init__(
        self,
        scales: List[Union[int, float]],
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        if self.identity_param not in scales:
            scales = [self.identity_param, *scales]
        self.interpolation = interpolation
        self.align_corners = align_corners
        super().__init__("scale", scales)

    def _rescale(self, tensor, factor):
        """Scale ``tensor`` by ``factor`` with this transform's interpolation settings."""
        return F.scale(
            tensor,
            factor,
            interpolation=self.interpolation,
            align_corners=self.align_corners,
        )

    def apply_aug_image(self, image, scale=1, **kwargs):
        if scale == self.identity_param:
            return image
        return self._rescale(image, scale)

    def apply_deaug_mask(self, mask, scale=1, **kwargs):
        # Undo the forward scaling with the reciprocal factor.
        # NOTE(review): F.scale computes integer output sizes, so scaling by
        # 1 / scale may not restore the exact original size for every
        # factor -- confirm for the scales in use.
        if scale == self.identity_param:
            return mask
        return self._rescale(mask, 1 / scale)

    def apply_deaug_label(self, label, scale=1, **kwargs):
        # Labels carry no spatial layout, so they pass through untouched.
        return label

    def apply_deaug_keypoints(self, keypoints, scale=1, **kwargs):
        # Keypoints are returned unchanged by this transform.
        return keypoints
class Resize(DualTransform):
    """Resize images to fixed spatial sizes.

    Args:
        sizes (List[Tuple[int, int]]): target spatial sizes; each one is
            passed as the ``size`` argument of ``F.resize``.
        original_size (Tuple[int, int], optional): size of the un-augmented
            image. When given, it is prepended to ``sizes`` (if missing) so
            the original resolution is part of the augmentation set. It is
            required for de-augmenting masks.
        interpolation (str): interpolation mode, e.g. "nearest" or
            "bilinear" (see ``torch.nn.functional.interpolate``).
        align_corners (bool): see ``torch.nn.functional.interpolate``.

    Raises:
        ValueError: from ``apply_deaug_mask`` when ``original_size`` was not
            provided.
    """

    def __init__(
        self,
        sizes: List[Tuple[int, int]],
        original_size: Optional[Tuple[int, int]] = None,
        interpolation: str = "nearest",
        align_corners: Optional[bool] = None,
    ):
        if original_size is not None and original_size not in sizes:
            sizes = [original_size] + list(sizes)
        self.interpolation = interpolation
        self.align_corners = align_corners
        self.original_size = original_size
        super().__init__("size", sizes)

    def apply_aug_image(self, image, size, **kwargs):
        # Resizing to the original size would be a no-op, so skip it.
        if size != self.original_size:
            image = F.resize(
                image,
                size,
                interpolation=self.interpolation,
                align_corners=self.align_corners,
            )
        return image

    def apply_deaug_mask(self, mask, size, **kwargs):
        # Masks are mapped back by resizing them to the original size, which
        # therefore has to be known.
        if self.original_size is None:
            raise ValueError(
                "Provide original image size to make mask backward transformation"
            )
        if size != self.original_size:
            mask = F.resize(
                mask,
                self.original_size,
                interpolation=self.interpolation,
                align_corners=self.align_corners,
            )
        return mask

    def apply_deaug_label(self, label, size=1, **kwargs):
        # Labels carry no spatial layout, so they pass through untouched.
        return label

    def apply_deaug_keypoints(self, keypoints, size=1, **kwargs):
        # Keypoints are returned unchanged by this transform.
        return keypoints
class Add(ImageOnlyTransform):
    """Add a constant value to every pixel.

    Args:
        values (List[float]): values to add; the identity value 0 is
            prepended when missing.
    """

    # Adding 0 leaves the image unchanged.
    identity_param = 0

    def __init__(self, values: List[float]):
        if self.identity_param not in values:
            values = [self.identity_param, *values]
        super().__init__("value", values)

    def apply_aug_image(self, image, value=0, **kwargs):
        if value == self.identity_param:
            return image
        return F.add(image, value)
class Multiply(ImageOnlyTransform):
    """Multiply every pixel by a constant factor.

    Args:
        factors (List[float]): multiplication factors; the identity factor 1
            is prepended when missing.
    """

    # Multiplying by 1 leaves the image unchanged.
    identity_param = 1

    def __init__(self, factors: List[float]):
        if self.identity_param not in factors:
            factors = [self.identity_param, *factors]
        super().__init__("factor", factors)

    def apply_aug_image(self, image, factor=1, **kwargs):
        if factor == self.identity_param:
            return image
        return F.multiply(image, factor)
class FiveCrops(ImageOnlyTransform):
    """Make five crops per image: the four corners plus the center.

    Args:
        crop_height (int): crop height in pixels
        crop_width (int): crop width in pixels
    """

    def __init__(self, crop_height, crop_width):
        crop_size = dict(crop_h=crop_height, crop_w=crop_width)
        # Order: left-top, left-bottom, right-bottom, right-top, center.
        corner_and_center = (
            F.crop_lt,
            F.crop_lb,
            F.crop_rb,
            F.crop_rt,
            F.center_crop,
        )
        crop_functions = tuple(partial(fn, **crop_size) for fn in corner_and_center)
        super().__init__("crop_fn", crop_functions)

    def apply_aug_image(self, image, crop_fn=None, **kwargs):
        return crop_fn(image)

    def apply_deaug_mask(self, mask, **kwargs):
        # Crops keep only part of the image, so masks are not mapped back.
        raise ValueError("`FiveCrop` augmentation is not suitable for mask!")

    def apply_deaug_keypoints(self, keypoints, **kwargs):
        raise ValueError("`FiveCrop` augmentation is not suitable for keypoints!")
下面是transforms中通过F调用的各个函数(functional模块)的实现。它们都比较简单,基本是直接调用torch的方法,只有中心裁剪(center_crop)加了注释。
import torch
import torch.nn.functional as F
def rot90(x, k=1):
    """Rotate a batch of images by ``k`` quarter turns.

    The rotation happens in the plane of dims 2 and 3, from dim 2 towards
    dim 3 when ``k > 0`` (``torch.rot90`` semantics).
    """
    return x.rot90(k, dims=(2, 3))
def hflip(x):
    """Mirror a batch of images left-to-right by reversing dim 3."""
    return torch.flip(x, dims=(3,))
def vflip(x):
    """Mirror a batch of images top-to-bottom by reversing dim 2."""
    return torch.flip(x, dims=(2,))
def sum(x1, x2):
    """Return the element-wise sum ``x1 + x2``.

    Note: this module-level name shadows the built-in ``sum``.
    """
    total = x1 + x2
    return total
def add(x, value):
    """Return ``x`` with ``value`` added to every element."""
    shifted = x + value
    return shifted
def max(x1, x2):
    """Return the element-wise maximum of two tensors via ``torch.max``.

    Note: this module-level name shadows the built-in ``max``.
    """
    larger = torch.max(x1, x2)
    return larger
def min(x1, x2):
    """Return the element-wise minimum of two tensors via ``torch.min``.

    Note: this module-level name shadows the built-in ``min``.
    """
    smaller = torch.min(x1, x2)
    return smaller
def multiply(x, factor):
    """Return ``x`` with every element multiplied by ``factor``."""
    scaled = x * factor
    return scaled
def scale(x, scale_factor, interpolation="nearest", align_corners=None):
    """Scale a batch of images by ``scale_factor`` with the given interpolation mode.

    Args:
        x: 4-D image batch; ``x.shape[2:]`` must be the spatial (h, w) size.
        scale_factor: multiplier applied to both spatial dimensions.
        interpolation: ``mode`` for ``torch.nn.functional.interpolate``
            (e.g. "nearest", "bilinear").
        align_corners: passed through to ``interpolate``.

    Returns:
        The batch resized to ``floor(h * scale_factor)`` x
        ``floor(w * scale_factor)``.
    """
    h, w = x.shape[2:]
    # Round away floating-point noise before truncating: e.g. 100 * 0.29
    # evaluates to 28.999999999999996, which a bare int() would turn into 28
    # instead of 29.
    new_h = int(round(h * scale_factor, 6))
    new_w = int(round(w * scale_factor, 6))
    return F.interpolate(
        x, size=(new_h, new_w), mode=interpolation, align_corners=align_corners
    )
def resize(x, size, interpolation="nearest", align_corners=None):
    """Resize a batch of images to the spatial ``size`` with the given interpolation mode."""
    options = dict(mode=interpolation, align_corners=align_corners)
    return F.interpolate(x, size=size, **options)
def crop(x, x_min=None, x_max=None, y_min=None, y_max=None):
    """Crop a batch of images to rows ``y_min:y_max`` and columns ``x_min:x_max``.

    Bounds left as ``None`` leave that side of the axis open, exactly as in
    ordinary Python slicing.
    """
    rows = slice(y_min, y_max)
    cols = slice(x_min, x_max)
    return x[:, :, rows, cols]
def crop_lt(x, crop_h, crop_w):
    """Crop the top ``crop_h`` rows and left ``crop_w`` columns (left-top corner)."""
    return x[:, :, :crop_h, :crop_w]
def crop_lb(x, crop_h, crop_w):
    """Crop the left-bottom corner of a batch of images.

    Args:
        x: image batch indexed as ``x[:, :, rows, cols]`` (assumed N, C, H, W).
        crop_h: crop height in pixels.
        crop_w: crop width in pixels.

    Returns:
        The bottom ``crop_h`` rows and left ``crop_w`` columns. A crop larger
        than the image keeps the whole axis; a zero-size crop is empty.
    """
    # Compute the start row explicitly instead of slicing with ``-crop_h``:
    # ``-0 == 0``, so ``-crop_h:`` would return every row for a zero height.
    y_start = max(x.shape[2] - crop_h, 0)
    return x[:, :, y_start:, 0:crop_w]
def crop_rt(x, crop_h, crop_w):
    """Crop the right-top corner of a batch of images.

    Args:
        x: image batch indexed as ``x[:, :, rows, cols]`` (assumed N, C, H, W).
        crop_h: crop height in pixels.
        crop_w: crop width in pixels.

    Returns:
        The top ``crop_h`` rows and right ``crop_w`` columns. A crop larger
        than the image keeps the whole axis; a zero-size crop is empty.
    """
    # Compute the start column explicitly instead of slicing with
    # ``-crop_w``: ``-0 == 0``, so ``-crop_w:`` would keep every column for a
    # zero width.
    x_start = max(x.shape[3] - crop_w, 0)
    return x[:, :, 0:crop_h, x_start:]
def crop_rb(x, crop_h, crop_w):
    """Crop the right-bottom corner of a batch of images.

    Args:
        x: image batch indexed as ``x[:, :, rows, cols]`` (assumed N, C, H, W).
        crop_h: crop height in pixels.
        crop_w: crop width in pixels.

    Returns:
        The bottom ``crop_h`` rows and right ``crop_w`` columns. A crop larger
        than the image keeps the whole axis; a zero-size crop is empty.
    """
    # Compute start indices explicitly instead of slicing with negative sizes:
    # ``-0 == 0``, so ``-crop_h:`` / ``-crop_w:`` would keep the whole axis
    # for a zero-size crop.
    y_start = max(x.shape[2] - crop_h, 0)
    x_start = max(x.shape[3] - crop_w, 0)
    return x[:, :, y_start:, x_start:]
def center_crop(x, crop_h, crop_w):
    """Crop the central ``crop_h`` x ``crop_w`` region of a batch of images.

    Args:
        x: image batch indexed as ``x[:, :, rows, cols]`` (assumed N, C, H, W).
        crop_h: crop height in pixels.
        crop_w: crop width in pixels.

    Returns:
        The centered crop. When the crop is larger than the image along an
        axis, that whole axis is kept.
    """
    # Locate the image center, then split the crop around it; an odd crop
    # size puts the extra pixel after the center.
    center_h = x.shape[2] // 2
    center_w = x.shape[3] // 2
    half_crop_h = crop_h // 2
    half_crop_w = crop_w // 2
    y_min = center_h - half_crop_h
    y_max = center_h + half_crop_h + crop_h % 2
    x_min = center_w - half_crop_w
    x_max = center_w + half_crop_w + crop_w % 2
    # Clamp the starts: a negative start would be read as an offset from the
    # end of the axis and silently return the wrong (tiny) region.
    y_min = max(y_min, 0)
    x_min = max(x_min, 0)
    return x[:, :, y_min:y_max, x_min:x_max]
def _disassemble_keypoints(keypoints):
    """Split keypoints into x and y tensors (entries 0 and 1 of the last axis)."""
    return keypoints[..., 0], keypoints[..., 1]
def _assemble_keypoints(x, y):
    """Stack x and y coordinate tensors back into keypoints along a new last axis."""
    return torch.stack((x, y), dim=-1)
def keypoints_hflip(keypoints):
    """Mirror keypoints horizontally by mapping x to ``1 - x``.

    Presumably expects coordinates normalized to [0, 1] -- TODO confirm.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    mirrored_xs = 1. - xs
    return _assemble_keypoints(mirrored_xs, ys)
def keypoints_vflip(keypoints):
    """Mirror keypoints vertically by mapping y to ``1 - y``.

    Presumably expects coordinates normalized to [0, 1] -- TODO confirm.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    mirrored_ys = 1. - ys
    return _assemble_keypoints(xs, mirrored_ys)
def keypoints_rot90(keypoints, k=1):
    """Rotate keypoints by ``k`` quarter turns.

    ``k=1`` maps (x, y) to (y, 1 - x), ``k=2`` to (1 - x, 1 - y) and ``k=3``
    to (1 - y, x). Presumably expects coordinates normalized to [0, 1] --
    TODO confirm.

    Raises:
        ValueError: if ``k`` is not one of 0, 1, 2, 3.
    """
    if k not in {0, 1, 2, 3}:
        raise ValueError("Parameter k must be in [0:3]")
    if k == 0:
        return keypoints
    xs, ys = _disassemble_keypoints(keypoints)
    if k == 1:
        return _assemble_keypoints(ys, 1. - xs)
    if k == 2:
        return _assemble_keypoints(1. - xs, 1. - ys)
    # Only k == 3 remains after the validation above.
    return _assemble_keypoints(1. - ys, xs)