PyTorch TTA（测试时增强，Test-Time Augmentation）源码阅读

PyTorch TTA 源码阅读

1. ttach/wrappers.py

TTA 对外调用的主要接口

下面三个包装类都继承自 PyTorch 的 nn.Module。

import itertools
from functools import partial
# typing: standard-library support for type annotations
# reference: https://www.bilibili.com/read/cv3249320/
from typing import Callable, Iterator, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn

from .base import Merger, Compose


class SegmentationTTAWrapper(nn.Module):
    """Run a segmentation model under test-time augmentation and merge the results.

    For every ``Transformer`` yielded by ``transforms`` the input image is
    augmented, passed through ``model``, mapped back through the reversed
    transform chain with ``deaugment_mask`` and accumulated in a ``Merger``.

    Args:
        model (torch.nn.Module): segmentation model with single input and single output
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_mask_key (str): if model output is `dict`, specify which key belong to `mask`
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_mask_key: Optional[str] = None,
    ):
        # nn.Module.__init__ must run before a submodule (`model`) is assigned.
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_mask_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Return the merged mask prediction for ``image``.

        Extra positional ``args`` are forwarded unchanged to every model call.
        When ``output_key`` is set, the model output is indexed with it and the
        merged result is returned as ``{output_key: merged}``.
        """
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        # One model call per parameter combination produced by the Compose.
        for tta in self.transforms:
            prediction = self.model(tta.augment_image(image), *args)
            if self.output_key is not None:
                prediction = prediction[self.output_key]
            # Undo the augmentation on the predicted mask before merging.
            merger.append(tta.deaugment_mask(prediction))

        merged = merger.result
        if self.output_key is None:
            return merged
        return {self.output_key: merged}


class ClassificationTTAWrapper(nn.Module):
    """Run a classification model under test-time augmentation and merge the results.

    Each augmented copy of the input is classified by ``model``; the prediction
    is passed through ``deaugment_label`` and accumulated in a ``Merger``.

    Args:
        model (torch.nn.Module): classification model with single input and single output
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_label_key (str): if model output is `dict`, specify which key belong to `label`
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_label_key: Optional[str] = None,
    ):
        # nn.Module.__init__ must run before a submodule (`model`) is assigned.
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_label_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Return the merged label prediction for ``image``.

        Extra positional ``args`` are forwarded unchanged to every model call.
        When ``output_key`` is set, the model output is indexed with it and the
        merged result is returned as ``{output_key: merged}``.
        """
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        for tta in self.transforms:
            prediction = self.model(tta.augment_image(image), *args)
            if self.output_key is not None:
                prediction = prediction[self.output_key]
            merger.append(tta.deaugment_label(prediction))

        merged = merger.result
        if self.output_key is None:
            return merged
        return {self.output_key: merged}


class KeypointsTTAWrapper(nn.Module):
    """Wrap PyTorch nn.Module (keypoints model) with test time augmentation transforms

    Args:
        model (torch.nn.Module): keypoints model with single input and single output
         in format [x1,y1, x2, y2, ..., xn, yn]
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_keypoints_key (str): if model output is `dict`, specify which key belong to `keypoints`
        scaled (bool): True if model return x, y scaled values in [0, 1], else False

    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_keypoints_key: Optional[str] = None,
        scaled: bool = False,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_keypoints_key
        self.scaled = scaled

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Return merged keypoints for ``image`` as a ``(batch, 2 * n_points)`` tensor.

        Each prediction is reshaped to ``(batch, n_points, 2)``. When
        ``scaled`` is False it is divided by (width, height) before
        ``deaugment_keypoints``, and the merged result is multiplied back
        afterwards. When ``output_key`` is set, the result is wrapped as
        ``{output_key: result}``.
        """
        merger = Merger(type=self.merge_mode, n=len(self.transforms))
        size = image.size()
        # Assumes image is laid out as (batch, channels, height, width) — TODO confirm.
        batch_size, image_height, image_width = size[0], size[2], size[3]

        for transformer in self.transforms:
            augmented_image = transformer.augment_image(image)
            augmented_output = self.model(augmented_image, *args)

            if self.output_key is not None:
                augmented_output = augmented_output[self.output_key]

            # Last axis holds (x, y) pairs.
            augmented_output = augmented_output.reshape(batch_size, -1, 2)
            if not self.scaled:
                # Out-of-place on purpose. reshape() may return a view of the
                # model's output, or of `image` itself for an identity pipeline
                # with a pass-through model. In-place division would corrupt
                # tensors still in use and break autograd. Broadcasting over
                # the last axis divides x by width and y by height.
                augmented_output = augmented_output / augmented_output.new_tensor(
                    [image_width, image_height]
                )

            deaugmented_output = transformer.deaugment_keypoints(augmented_output)
            merger.append(deaugmented_output)

        result = merger.result

        if not self.scaled:
            # Out-of-place for the same reason: with n == 1 the merged result
            # may be the very tensor produced by the model.
            result = result * result.new_tensor([image_width, image_height])
        result = result.reshape(batch_size, -1)

        if self.output_key is not None:
            result = {self.output_key: result}

        return result

1.1 Merger

class Merger:
    """Accumulate de-augmented predictions and merge them into one.

    Args:
        type: merge method, one of 'mean', 'gmean', 'sum', 'max', 'min', 'tsharpen'.
        n: number of predictions that will be appended. 'mean', 'gmean' and
            'tsharpen' divide by / take the root of this value, so it must match
            the actual number of ``append`` calls.

    Raises:
        ValueError: if ``type`` is not a supported merge method.
    """

    def __init__(
            self,
            type: str = 'mean',
            n: int = 1,
    ):

        if type not in ['mean', 'gmean', 'sum', 'max', 'min', 'tsharpen']:
            raise ValueError('Not correct merge type `{}`.'.format(type))

        self.output = None
        self.type = type
        self.n = n

    def append(self, x):
        """Fold one prediction tensor ``x`` into the running output."""
        if self.type == 'tsharpen':
            # tsharpen averages the square roots of the predictions.
            x = x ** 0.5
        # The first prediction simply initializes the accumulator.
        if self.output is None:
            self.output = x
        elif self.type in ['mean', 'sum', 'tsharpen']:
            self.output = self.output + x
        elif self.type == 'gmean':
            # Running product; `result` takes the n-th root.
            self.output = self.output * x
        elif self.type == 'max':
            # Two-tensor torch.max is the element-wise maximum.
            self.output = torch.max(self.output, x)
        elif self.type == 'min':
            # Two-tensor torch.min is the element-wise minimum.
            self.output = torch.min(self.output, x)

    # Read-only property: the merged prediction is derived from the accumulator.
    @property
    def result(self):
        """The merged prediction.

        If nothing was appended, ``output`` is None: 'sum'/'max'/'min' then
        return None, while the other modes raise TypeError.
        """
        if self.type in ['sum', 'max', 'min']:
            result = self.output
        elif self.type in ['mean', 'tsharpen']:
            result = self.output / self.n
        elif self.type in ['gmean']:
            result = self.output ** (1 / self.n)
        else:
            raise ValueError('Not correct merge type `{}`.'.format(self.type))
        return result

1.2 Compose

class Compose:
    """Combine several transforms into every possible parameter combination.

    Each transform ``t`` exposes its candidate values in ``t.params``. Compose
    takes the Cartesian product of those lists, and iterating over it yields
    one ``Transformer`` per combination. ``len()`` is the number of
    combinations, i.e. the number of model calls a TTA wrapper makes.

    Args:
        transforms: the transforms to combine (see ``BaseTransform``).
    """

    def __init__(
            self,
            transforms: List["BaseTransform"],
    ):
        self.aug_transforms = transforms
        # Cartesian product of every transform's params. With 3 transforms of
        # 3 params each this yields 3 * 3 * 3 = 27 tuples:
        #   (t1.p1, t2.p1, t3.p1), (t1.p1, t2.p1, t3.p2), ..., (t1.p3, t2.p3, t3.p3)
        self.aug_transform_parameters = list(itertools.product(*[t.params for t in self.aug_transforms]))
        # De-augmentation walks the chain backwards. Reverse both the transform
        # order and each parameter tuple so they stay paired.
        self.deaug_transforms = transforms[::-1]
        self.deaug_transform_parameters = [p[::-1] for p in self.aug_transform_parameters]

    def __iter__(self) -> Iterator["Transformer"]:
        """Yield one ``Transformer`` per parameter combination.

        Each step binds its parameter by keyword name ``t.pname`` using
        ``functools.partial``. For example, ``partial(add, y=1)(2)`` is
        ``add(2, y=1)``.
        """
        for aug_params, deaug_params in zip(self.aug_transform_parameters, self.deaug_transform_parameters):
            image_aug_chain = Chain([partial(t.apply_aug_image, **{t.pname: p})
                                     for t, p in zip(self.aug_transforms, aug_params)])
            mask_deaug_chain = Chain([partial(t.apply_deaug_mask, **{t.pname: p})
                                      for t, p in zip(self.deaug_transforms, deaug_params)])
            label_deaug_chain = Chain([partial(t.apply_deaug_label, **{t.pname: p})
                                       for t, p in zip(self.deaug_transforms, deaug_params)])
            keypoints_deaug_chain = Chain([partial(t.apply_deaug_keypoints, **{t.pname: p})
                                           for t, p in zip(self.deaug_transforms, deaug_params)])
            yield Transformer(
                image_pipeline=image_aug_chain,
                mask_pipeline=mask_deaug_chain,
                label_pipeline=label_deaug_chain,
                keypoints_pipeline=keypoints_deaug_chain
            )

    def __len__(self) -> int:
        """Number of parameter combinations (one per yielded Transformer)."""
        return len(self.aug_transform_parameters)

1.2.1 BaseTransform

class BaseTransform:
    """Interface shared by all test-time transforms.

    Subclasses override the ``apply_*`` hooks. ``Compose`` calls each hook
    with the current parameter value passed as the keyword argument named
    ``pname``.

    Args:
        name: keyword under which parameter values are passed to the hooks.
        params: candidate parameter values; ``Compose`` forms the Cartesian
            product of these across transforms.
    """

    # Class-level default. It is not read anywhere in the visible code;
    # presumably subclasses set the "no-op" parameter value here — TODO confirm.
    identity_param = None

    def __init__(
            self,
            name: str,
            params: Union[list, tuple],
    ):
        self.pname = name
        self.params = params

    def apply_aug_image(self, image, *args, **params):
        """Augment an input image. Must be overridden by subclasses."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Undo the augmentation on a predicted mask. Must be overridden."""
        raise NotImplementedError

    def apply_deaug_label(self, label, *args, **params):
        """Undo the augmentation on a predicted label. Must be overridden."""
        raise NotImplementedError

    def apply_deaug_keypoints(self, keypoints, *args, **params):
        """Undo the augmentation on predicted keypoints. Must be overridden."""
        raise NotImplementedError

1.2.2 Transformer

class Transformer:
    """Bundle of the pipelines built for one parameter combination.

    ``Compose.__iter__`` constructs one per combination. ``image_pipeline``
    applies the augmentation, and the other three pipelines apply the reversed
    transform chain to masks, labels and keypoints respectively. Each pipeline
    is any callable taking and returning a single value (a ``Chain`` in
    practice). Annotations are quoted forward references because ``Chain`` is
    defined after this class.
    """

    def __init__(
            self,
            image_pipeline: "Chain",
            mask_pipeline: "Chain",
            label_pipeline: "Chain",
            keypoints_pipeline: "Chain"
    ):
        self.image_pipeline = image_pipeline
        self.mask_pipeline = mask_pipeline
        self.label_pipeline = label_pipeline
        self.keypoints_pipeline = keypoints_pipeline

    def augment_image(self, image):
        """Apply the augmentation pipeline to an input image."""
        return self.image_pipeline(image)

    def deaugment_mask(self, mask):
        """Apply the reversed chain to a predicted mask."""
        return self.mask_pipeline(mask)

    def deaugment_label(self, label):
        """Apply the reversed chain to a predicted label."""
        return self.label_pipeline(label)

    def deaugment_keypoints(self, keypoints):
        """Apply the reversed chain to predicted keypoints."""
        return self.keypoints_pipeline(keypoints)

1.2.3 Chain

class Chain:
    """Callable that applies a list of functions in order.

    ``Chain([f, g])(x)`` returns ``g(f(x))``. An empty list, or ``None``,
    makes the chain the identity.
    """

    def __init__(
            self,
            functions: Optional[List[Callable]]
    ):
        # typing.Callable is the annotation for "any callable object"; the
        # builtin `callable` is a predicate function, not a type.
        # Reference: https://docs.python.org/3/library/typing.html#typing.Callable
        self.functions = functions or []

    # __call__ makes Chain instances callable like functions.
    def __call__(self, x):
        for f in self.functions:
            x = f(x)
        return x

其实主要功能都在 base.py 中实现，因此对外暴露的几个包装类（wrappers）写法都很类似，关键在于理解 base.py 中这几个类的用法。

你可能感兴趣的:(小记)