TTA 对外调用的主要接口(下面的三个 Wrapper 类),
它们都继承自 PyTorch 的 nn.Module。
import itertools
from functools import partial
# Library used for type annotations.
# Reference: https://www.bilibili.com/read/cv3249320/
from typing import Callable, Iterator, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn

from .base import Merger, Compose
class SegmentationTTAWrapper(nn.Module):
    """Run a segmentation model under test-time augmentation (TTA) and merge the masks.

    Args:
        model (torch.nn.Module): trained segmentation model with a single input and
            a single output (``.forward(x)`` returns either a ``torch.Tensor`` or a
            ``Mapping[str, torch.Tensor]``).
        transforms (ttach.Compose): composition of test-time transforms; iterating
            it yields one transformer per augmentation.
        merge_mode (str): how the augmented predictions are merged:
            mean/gmean/max/min/sum/tsharpen.
        output_mask_key (str): when the model output is a ``dict``, the key whose
            value is the mask. ``None`` (default) means the output is used as is.
    """

    def __init__(
        self,
        model: nn.Module,  # trained model to wrap with TTA
        transforms: Compose,  # iterable composition of augmentations
        merge_mode: str = "mean",  # strategy used to merge the predictions
        output_mask_key: Optional[str] = None,  # may be a key or None
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_mask_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Predict on every augmented view of ``image`` and merge the results.

        Extra positional ``args`` are forwarded unchanged to the wrapped model.
        The return value is the merged mask, wrapped in ``{output_key: mask}``
        when an output key was configured.
        """
        key = self.output_key
        # One accumulator per call; ``n`` is the number of augmentations.
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        for tta in self.transforms:
            # Augment the input, then run the wrapped model on it.
            prediction = self.model(tta.augment_image(image), *args)
            if key is not None:
                prediction = prediction[key]
            # Map the predicted mask back through the de-augmentation pipeline
            # before accumulating it.
            merger.append(tta.deaugment_mask(prediction))

        merged = merger.result
        return merged if key is None else {key: merged}
class ClassificationTTAWrapper(nn.Module):
    """Run a classification model under test-time augmentation (TTA) and merge the labels.

    Args:
        model (torch.nn.Module): classification model with a single input and a
            single output (``.forward(x)`` returns either a ``torch.Tensor`` or a
            ``Mapping[str, torch.Tensor]``).
        transforms (ttach.Compose): composition of test-time transforms; iterating
            it yields one transformer per augmentation.
        merge_mode (str): how the augmented predictions are merged:
            mean/gmean/max/min/sum/tsharpen.
        output_label_key (str): when the model output is a ``dict``, the key whose
            value is the label. ``None`` (default) means the output is used as is.
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_label_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_label_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Predict on every augmented view of ``image`` and merge the results.

        Extra positional ``args`` are forwarded unchanged to the wrapped model.
        The return value is the merged label tensor, wrapped in
        ``{output_key: labels}`` when an output key was configured.
        """
        key = self.output_key
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        for tta in self.transforms:
            prediction = self.model(tta.augment_image(image), *args)
            if key is not None:
                prediction = prediction[key]
            # Labels go through the label de-augmentation pipeline before merging.
            merger.append(tta.deaugment_label(prediction))

        merged = merger.result
        return merged if key is None else {key: merged}
class KeypointsTTAWrapper(nn.Module):
    """Run a keypoints model under test-time augmentation (TTA) and merge the keypoints.

    Args:
        model (torch.nn.Module): keypoints model with a single input and a single
            output in the format ``[x1, y1, x2, y2, ..., xn, yn]``
            (``.forward(x)`` returns either a ``torch.Tensor`` or a
            ``Mapping[str, torch.Tensor]``).
        transforms (ttach.Compose): composition of test-time transforms; iterating
            it yields one transformer per augmentation.
        merge_mode (str): how the augmented predictions are merged:
            mean/gmean/max/min/sum/tsharpen.
        output_keypoints_key (str): when the model output is a ``dict``, the key
            whose value holds the keypoints.
        scaled (bool): True if the model returns x, y values scaled to [0, 1],
            else False.
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_keypoints_key: Optional[str] = None,
        scaled: bool = False,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_keypoints_key
        self.scaled = scaled

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Predict keypoints on every augmented view of ``image`` and merge them.

        Extra positional ``args`` are forwarded unchanged to the wrapped model.
        The result is reshaped to ``(batch_size, -1)`` and wrapped in
        ``{output_key: keypoints}`` when an output key was configured.
        """
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        # Dimensions 0, 2 and 3 of the input are read as batch, height and width.
        dims = image.size()
        batch_size = dims[0]
        image_height = dims[2]
        image_width = dims[3]

        for tta in self.transforms:
            keypoints = self.model(tta.augment_image(image), *args)
            if self.output_key is not None:
                keypoints = keypoints[self.output_key]

            # Flat [x1, y1, ...] -> (batch, n_points, 2) so x and y can be indexed.
            keypoints = keypoints.reshape(batch_size, -1, 2)

            if not self.scaled:
                # Normalise pixel coordinates by the image size before de-augmenting.
                # NOTE(review): these are in-place updates; if ``reshape`` returned
                # a view they also write into the tensor produced by the model.
                keypoints[..., 0] /= image_width
                keypoints[..., 1] /= image_height

            merger.append(tta.deaugment_keypoints(keypoints))

        merged = merger.result

        if not self.scaled:
            # Convert the merged, normalised coordinates back to pixels.
            merged[..., 0] *= image_width
            merged[..., 1] *= image_height

        merged = merged.reshape(batch_size, -1)

        if self.output_key is None:
            return merged
        return {self.output_key: merged}
class Merger:
    """Accumulate per-augmentation predictions and reduce them to one TTA result.

    Supported merge types:

    * ``'sum'``      -- running sum of the appended values.
    * ``'mean'``     -- running sum divided by ``n``.
    * ``'tsharpen'`` -- each value is raised to the power 0.5 before being
      summed; the sum is divided by ``n``.
    * ``'gmean'``    -- running product raised to the power ``1 / n``.
    * ``'max'`` / ``'min'`` -- element-wise maximum / minimum of the values.

    Args:
        type (str): merge strategy, one of the names listed above.
        n (int): number of predictions that will be appended (the number of
            augmentations); used as the divisor / root in ``result``.

    Raises:
        ValueError: if ``type`` is not one of the supported names.
    """

    def __init__(
        self,
        type: str = 'mean',  # merge strategy used for the TTA prediction
        n: int = 1,  # number of transforms
    ):
        if type not in ['mean', 'gmean', 'sum', 'max', 'min', 'tsharpen']:
            raise ValueError('Not correct merge type `{}`.'.format(type))

        self.output = None
        self.type = type
        self.n = n

    def append(self, x):
        """Fold one more prediction ``x`` into the running accumulator."""
        if self.type == 'tsharpen':
            x = x ** 0.5

        # The accumulator starts as None, so the first value is stored as is;
        # later values are combined according to the merge type.
        if self.output is None:
            self.output = x
        elif self.type in ['mean', 'sum', 'tsharpen']:
            self.output = self.output + x
        elif self.type == 'gmean':
            self.output = self.output * x
        elif self.type == 'max':
            # torch.max with two tensors is the element-wise maximum. The
            # original called ``F.max``, but no ``F`` is imported in this file.
            self.output = torch.max(self.output, x)
        elif self.type == 'min':
            # Element-wise minimum, see the note on 'max' above.
            self.output = torch.min(self.output, x)

    # ``@property`` exposes ``result`` as a read-only attribute with the same
    # name as the method. Reference: https://zhuanlan.zhihu.com/p/64487092
    @property
    def result(self):
        """Return the merged prediction for the configured merge type.

        If nothing has been appended the accumulator is still ``None``: the
        'sum'/'max'/'min' types then return ``None`` and the other types fail
        on the arithmetic with ``None``.

        Raises:
            ValueError: if ``self.type`` was changed to an unsupported name
                after construction.
        """
        if self.type in ['sum', 'max', 'min']:
            result = self.output
        elif self.type in ['mean', 'tsharpen']:
            result = self.output / self.n
        elif self.type in ['gmean']:
            result = self.output ** (1 / self.n)
        else:
            raise ValueError('Not correct merge type `{}`.'.format(self.type))
        return result
class Compose:
    """Combine several transforms into every parameter combination for TTA.

    Iterating a ``Compose`` yields one ``Transformer`` per parameter
    combination; ``len()`` is the number of combinations.

    Args:
        transforms: the transforms to combine. Each one must expose ``params``
            (its candidate parameter values), ``pname`` (the keyword name under
            which a parameter value is passed) and the ``apply_aug_image`` /
            ``apply_deaug_mask`` / ``apply_deaug_label`` /
            ``apply_deaug_keypoints`` methods.
    """

    def __init__(
        self,
        transforms: List["BaseTransform"],
    ):
        self.aug_transforms = transforms
        # itertools.product builds the Cartesian product of the parameter lists:
        # one tuple per combination, holding one value per transform. With three
        # transforms of three params each this gives 3 * 3 * 3 = 27 tuples:
        #   [(t1.p1, t2.p1, t3.p1), (t1.p1, t2.p1, t3.p2), ..., (t1.p3, t2.p3, t3.p3)]
        self.aug_transform_parameters = list(
            itertools.product(*[t.params for t in self.aug_transforms])
        )
        # De-augmentation runs the transforms in reverse order, so both the
        # transforms and each parameter tuple are reversed to stay aligned.
        self.deaug_transforms = transforms[::-1]
        self.deaug_transform_parameters = [p[::-1] for p in self.aug_transform_parameters]

    @staticmethod
    def _bind(transforms, params, method_name: str) -> "Chain":
        """Build a Chain of ``method_name`` calls, each bound to its parameter.

        ``partial`` pre-fills the keyword ``t.pname`` with the chosen value, so
        every element of the chain only needs the data as its single argument.
        """
        return Chain([
            partial(getattr(t, method_name), **{t.pname: p})
            for t, p in zip(transforms, params)
        ])

    # Written as a generator, so no separate ``__next__`` is needed.
    def __iter__(self) -> Iterator["Transformer"]:
        """Yield one Transformer per parameter combination."""
        for aug_params, deaug_params in zip(
            self.aug_transform_parameters, self.deaug_transform_parameters
        ):
            yield Transformer(
                image_pipeline=self._bind(
                    self.aug_transforms, aug_params, "apply_aug_image"),
                mask_pipeline=self._bind(
                    self.deaug_transforms, deaug_params, "apply_deaug_mask"),
                label_pipeline=self._bind(
                    self.deaug_transforms, deaug_params, "apply_deaug_label"),
                keypoints_pipeline=self._bind(
                    self.deaug_transforms, deaug_params, "apply_deaug_keypoints"),
            )

    def __len__(self) -> int:
        """Return the number of parameter combinations."""
        return len(self.aug_transform_parameters)
class BaseTransform:
    """Base class for a single test-time transform.

    Only the constructor does real work here. The four ``apply_*`` hooks raise
    ``NotImplementedError`` and are meant to be overridden by subclasses.

    Args:
        name (str): stored as ``pname``.
        params (list | tuple): stored as ``params``.
    """

    identity_param = None

    def __init__(self, name: str, params: Union[list, tuple]):
        self.pname = name
        self.params = params

    # Abstract hooks: the base class defines the interface only.

    def apply_aug_image(self, image, *args, **params):
        """Augment ``image``; must be implemented by subclasses."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """De-augment ``mask``; must be implemented by subclasses."""
        raise NotImplementedError

    def apply_deaug_label(self, label, *args, **params):
        """De-augment ``label``; must be implemented by subclasses."""
        raise NotImplementedError

    def apply_deaug_keypoints(self, keypoints, *args, **params):
        """De-augment ``keypoints``; must be implemented by subclasses."""
        raise NotImplementedError
class Transformer:
    """Bundle the four pipelines that belong to one augmentation.

    Each pipeline is a callable taking one argument (built as a ``Chain`` by
    ``Compose``); the methods below simply dispatch to the matching pipeline.

    Args:
        image_pipeline: applied to the input image by ``augment_image``.
        mask_pipeline: applied to a predicted mask by ``deaugment_mask``.
        label_pipeline: applied to a predicted label by ``deaugment_label``.
        keypoints_pipeline: applied to predicted keypoints by
            ``deaugment_keypoints``.
    """

    # Annotations are string forward references because ``Chain`` is defined
    # further down in this file and is not yet bound when this class is created.
    def __init__(
        self,
        image_pipeline: "Chain",
        mask_pipeline: "Chain",
        label_pipeline: "Chain",
        keypoints_pipeline: "Chain",
    ):
        self.image_pipeline = image_pipeline
        self.mask_pipeline = mask_pipeline
        self.label_pipeline = label_pipeline
        self.keypoints_pipeline = keypoints_pipeline

    def augment_image(self, image):
        """Return ``image`` passed through the image pipeline."""
        return self.image_pipeline(image)

    def deaugment_mask(self, mask):
        """Return ``mask`` passed through the mask pipeline."""
        return self.mask_pipeline(mask)

    def deaugment_label(self, label):
        """Return ``label`` passed through the label pipeline."""
        return self.label_pipeline(label)

    def deaugment_keypoints(self, keypoints):
        """Return ``keypoints`` passed through the keypoints pipeline."""
        return self.keypoints_pipeline(keypoints)
class Chain:
    """An ordered list of single-argument callables applied one after another.

    Args:
        functions: callables to apply in order; a falsy value (``None`` or an
            empty list) results in an empty chain that returns its input.
    """

    # ``Callable`` is the typing name for objects that can be called.
    # Reference: https://www.jianshu.com/p/429f00040555
    def __init__(
        self,
        functions: List[Callable],
    ):
        self.functions = functions or []

    # Defining ``__call__`` makes Chain instances callable themselves.
    def __call__(self, x):
        """Feed ``x`` through every function in order and return the result."""
        for f in self.functions:
            x = f(x)
        return x
其实主要功能都实现在 base 这个文件里,因此对外的几个接口(Wrapper)写法基本相同;要理解整个 TTA 流程,关键是弄清 base 里几个类(Merger、Compose、Transformer、Chain)的用法。