nnDetection Framework: Mirror Transform Analysis

The following is a detailed walkthrough of the Mirror part of the nnDetection source code. For detailed analyses of the other parts, see the author's other posts; essentially every method used here is analyzed separately, so use Ctrl + F to search for it.

The Mirror class mirrors the data and applies the same mirroring to the corner points of the predicted pred_boxes. The sections below analyze the mirroring flow in detail; the authors' trick of mirroring all points at once with a single matrix multiplication is elegant and worth learning from.
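
The core idea: write every point in homogeneous coordinates (x, y, 1) and multiply it with a matrix that has -1 on the diagonal of each mirrored axis and that axis' size in the last column, so the coordinate becomes shape[dim] - x along every mirrored dimension. A minimal 2D sketch of the idea (the numbers are made up, not taken from the repository):

import torch

S = torch.tensor([[-1., 0., 160.],   # mirror x inside a width-160 image
                  [ 0., 1.,   0.],   # leave y untouched
                  [ 0., 0.,   1.]])
p = torch.tensor([[10., 20., 1.]])   # point (10, 20) in homogeneous coordinates
print(p @ S.T)                       # tensor([[150.,  20.,   1.]])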

class Mirror (path: nndet/io/transforms/spatial.py)
Mirror -> AbstractTransform -> torch.nn.Module


from typing import Any, List, Sequence

import torch
from torch import Tensor


class AbstractTransform(torch.nn.Module):
    def __init__(self, grad: bool = False, **kwargs):
        """
        Args:
            grad: enable gradient computation inside transformation
        """
        super().__init__()
        self.grad = grad

    def __call__(self, *args, **kwargs) -> Any:
        """
        Call super class with correct torch context

        Args:
            *args: forwarded positional arguments
            **kwargs: forwarded keyword arguments

        Returns:
            Any: transformed data

        """
        if self.grad:
            context = torch.enable_grad()
        else:
            context = torch.no_grad()

        with context:
            return super().__call__(*args, **kwargs)  # nn.Module.__call__ dispatches to self.forward()


class Mirror(AbstractTransform):
    def __init__(self, keys: Sequence[str], dims: Sequence[int],
                 point_keys: Sequence[str] = (), box_keys: Sequence[str] = (),
                 grad: bool = False):
        """
        Mirror Transform

        Args:
            keys: keys to mirror (first key must correspond to data for
                shape information) expected shape [N, C, dims]
            dims: dimensions to mirror (starting from the first spatial
                dimension)
            point_keys: keys where points for transformation are located
                [N, dims]
            box_keys: keys where boxes are located; following format
                needs to be used (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
            grad: enable gradient computation inside transformation
        """
        super().__init__(grad=grad)
        self.dims = dims
        self.keys = keys
        self.point_keys = point_keys
        self.box_keys = box_keys

    def forward(self, **data) -> dict:
        """
        Implement transform functionality here

        Args:
            data: dict with data; it arrives in two forms:
                during transforms it is dict_keys(['data', 'tile_origin', 'crop']),
                during inverse_transforms it is dict_keys(['pred_boxes', 'pred_scores', 'pred_labels', 'pred_seg'])

        Returns:
            dict: dict with transformed data
        """
        for key in self.keys:
            data[key] = mirror(data[key], self.dims)

        data_shape = data[self.keys[0]].shape
        data_shapes = [tuple(data_shape[2:])] * data_shape[0]  # repeat the spatial shape N (batch size) times

        for key in self.box_keys:  # typically just one key, 'pred_boxes'
            # data['pred_boxes'] is a list with batch_size entries; each boxes tensor has shape (100, 6),
            # where 100 is model_detections_per_image from the config
            points = [boxes2points(b) for b in data[key]]  # rearrange each (x1, y1, x2, y2, z1, z2) box into two corner points (x, y, z), so points[0].shape == (200, 3)
            points = mirror_points(points, self.dims, data_shapes)
            data[key] = [points2boxes(p) for p in points]

        for key in self.point_keys:
            data[key] = mirror_points(data[key], self.dims, data_shapes)
        return data

    def invert(self, **data) -> dict:
        """
        Revert mirroring

        Args:
            **data: dict with data

        Returns:
            dict with re-transformed data
        """
        return self(**data)  # applying the same flips again undoes the mirroring performed in forward
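
A minimal usage sketch (not taken from the repository; the key names follow the dicts mentioned in forward, the tensor sizes are made up): mirroring a toy batch and then inverting it restores the original data and boxes.

t = Mirror(keys=("data",), dims=(0, 1), box_keys=("pred_boxes",))
batch = {
    "data": torch.rand(2, 1, 8, 8, 8),                        # [N, C, spatial dims]
    "pred_boxes": [torch.tensor([[1., 2., 3., 4., 0., 5.]]),  # one 3D box per batch element
                   torch.tensor([[0., 0., 2., 2., 1., 3.]])],
}
mirrored = t(**batch)
restored = t.invert(**mirrored)
# flipping twice restores the original data and boxes
assert torch.allclose(restored["data"], batch["data"])
assert all(torch.allclose(r, b) for r, b in zip(restored["pred_boxes"], batch["pred_boxes"]))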

def mirror(data: torch.Tensor, dims: Sequence[int]) -> torch.Tensor:
    """
    Mirror data at dims

    Args
        data: input data [N, C, spatial dims]
        dims: dimensions to mirror starting from spatial dims
            e.g. dim=(0,) mirror the first spatial dimension

    Returns
        torch.Tensor: tensor with mirrored dimensions
    """
    dims = [d + 2 for d in dims]  # input is [N, C, spatial dims], so +2 maps spatial dim indices onto tensor dims
    return data.flip(dims)  # flip (mirror) the data along those dimensions
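
A tiny check of what mirror() does on a made-up 1D example:

x = torch.arange(4).float().reshape(1, 1, 4)   # [N, C, spatial]
print(mirror(x, dims=(0,)))                    # tensor([[[3., 2., 1., 0.]]])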

def boxes2points(boxes: Tensor) -> Tensor:
    """
    Convert boxes to points

    Args:
        boxes: (x1, y1, x2, y2, (z1, z2))[N, dims *2]

    Returns:
        Tensor: points [N * 2, dims]
    """
    if boxes.shape[1] == 4:
        idx0 = [0, 1]
        idx1 = [2, 3]
    else:
        idx0 = [0, 1, 4]
        idx1 = [2, 3, 5]

    points0 = boxes[:, idx0]
    points1 = boxes[:, idx1]
    return torch.cat([points0, points1], dim=0)  # stack points0 and points1 row-wise into [2N, dims]; each row is one corner point
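
Example (made-up coordinates): two 3D boxes turn into their four corner points, first corners stacked on top of the second corners.

b = torch.tensor([[1., 2., 3., 4., 0., 5.],
                  [0., 0., 2., 2., 1., 3.]])
print(boxes2points(b))
# tensor([[1., 2., 0.],
#         [0., 0., 1.],
#         [3., 4., 5.],
#         [2., 2., 3.]])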

def mirror_points(points: Sequence[torch.Tensor], dims: Sequence[int],
                  data_shapes: Sequence[Sequence[int]]) -> List[torch.Tensor]:
    """
    Mirror points along given dimensions

    Args:
        points: points per batch element [N, dims]
        dims: dimensions to mirror
        data_shapes: shape of data

    Returns:
        Tensor: transformed points [N, dims]
    """
    cartesian_dims = points[0].shape[1]
    homogeneous_points = points_to_homogeneous(points)  # append a column of ones to each (200, 3) points tensor, giving shape (200, 4)

    transformed = []
    for points_per_image, data_shape in zip(homogeneous_points, data_shapes):
        matrix = nd_mirror_matrix(cartesian_dims, dims, data_shape).to(points_per_image)  # see nd_mirror_matrix below: a (4, 4) matrix encoding the mirrored dims and the data shape along them
        transformed.append(points_per_image.cpu() @ matrix.transpose(0, 1).cpu())  # mirror the points inside the data: along each mirrored dim the coordinate becomes data.shape[dim] - point[dim]
        """
        Example: points_per_image has shape (200, 4); after the @ (matrix multiplication) each point has been
        transformed to data.shape[dim] - point[dim] along every mirrored dimension.
        >>> a
        tensor([[ -1.,   0.,   0., 160.],
                [  0.,  -1.,   0., 112.],
                [  0.,   0.,  -1., 128.],
                [  0.,   0.,   0.,   1.]])
        >>> a.transpose(0,1)
        tensor([[ -1.,   0.,   0.,   0.],
                [  0.,  -1.,   0.,   0.],
                [  0.,   0.,  -1.,   0.],
                [160., 112., 128.,   1.]])
        """
    return points_to_cartesian(transformed)

def points_to_homogeneous(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    """
    Transforms points from cartesian to homogeneous coordinates
    (i.e. append a column of ones to each (200, 3) points tensor, giving shape (200, 4))

    Args:
        points: list of points to transform [N, dims] where N is the number
            of points and dims is the number of spatial dimensions

    Returns
        torch.Tensor: the batch of points in homogeneous coordinates [N, dim + 1]
    """
    return [torch.cat([p, torch.ones(p.shape[0], 1).to(p)], dim=1) for p in points]
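
For example (made-up point), a single 3D point gains a trailing 1:

print(points_to_homogeneous([torch.tensor([[10., 20., 30.]])]))
# [tensor([[10., 20., 30.,  1.]])]
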
def nd_mirror_matrix(cartesian_dims: int, mirror_dims: Sequence[int],
                     data_shape: Sequence[int]) -> torch.Tensor:
    """
该方法构造了一个单位矩阵,当原数据为3维数据时, shape为(4, 4), 然后将需要镜像的维度mat[dim][dim]处的值
赋值为-1, 将单位矩阵最后一列上,需要镜像的维度位置的值,赋值为该维度data对应的shape
mirror_dims : (1, 2)
offset_mask : tensor([0., 1., 1.]) torch.Size([3])
mat : tensor([[  1.,   0.,   0.,   0.],
              [  0.,  -1.,   0., 112.],
              [  0.,   0.,  -1., 128.],
              [  0.,   0.,   0.,   1.]]) torch.Size([4, 4])
              
mirror_dims : (0, 1, 2)
offset_mask : tensor([1., 1., 1.]) torch.Size([3])
mat : tensor([[ -1.,   0.,   0., 160.],
              [  0.,  -1.,   0., 112.],
              [  0.,   0.,  -1., 128.],
              [  0.,   0.,   0.,   1.]]) torch.Size([4, 4])
"""
    """
    Create n dimensional matrix to for mirroring

    Args:
        cartesian_dims: number of cartesian dimensions
        mirror_dims: dimensions to mirror
        data_shape: shape of image

    Returns:
        Tensor: matrix for mirroring in homogeneous coordinated,
            [cartesian_dims + 1, cartesian_dims + 1]
    """
    mirror_dims = tuple(mirror_dims)
    data_shape = list(data_shape)

    homogeneous_dims = cartesian_dims + 1
    mat = torch.eye(homogeneous_dims, dtype=torch.float)

    # reflection
    mat[[mirror_dims] * 2] = -1  # set mat[dim][dim] = -1 for every mirrored dimension of the identity matrix
    # add data shape to axis which were reflected
    self_tensor = torch.zeros(cartesian_dims, dtype=torch.float)  # e.g. tensor([0., 0., 0.])
    index_tensor = torch.Tensor(mirror_dims).long()  # indices of the mirrored dimensions
    src_tensor = torch.tensor([1] * len(mirror_dims), dtype=torch.float)  # a 1 for every mirrored dimension
    offset_mask = self_tensor.scatter_(0, index_tensor, src_tensor)  # 1 at every mirrored dim, 0 elsewhere, e.g. dims=(0,): tensor([1., 0., 0.]), dims=(1, 2): tensor([0., 1., 1.])
    mat[:-1, -1] = offset_mask * torch.tensor(data_shape)  # write the data shape into the last column of every mirrored row
    return mat
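
Quick check of the matrix in action (the volume shape (160, 112, 128) matches the docstring example, the point is made up): mirroring (10, 20, 30) in all three dimensions yields (150, 92, 98).

m = nd_mirror_matrix(3, (0, 1, 2), (160, 112, 128))
p = torch.tensor([[10., 20., 30., 1.]])    # homogeneous point
print(p @ m.transpose(0, 1))               # tensor([[150.,  92.,  98.,   1.]])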

def points_to_cartesian(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    """
    Transforms points in homogeneous coordinates back to cartesian
    coordinates.

    Args:
        points: homogeneous points [N, in_dims], N number of points,
            in_dims number of input dimensions (spatial dimensions + 1)

    Returns:
        List[Tensor]]: cartesian points [N, in_dims] = [N, dims]
    """
    return [p[..., :-1] / p[..., -1][:, None] for p in points]  # p has shape (200, 4); divide the first 3 columns by the homogeneous coordinate to get (200, 3) points (strictly, p[..., :-1] alone would suffice here, since the last column is always 1 and never changed)
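
For example (continuing the made-up point from above), the trivial homogeneous coordinate is divided out and dropped:

print(points_to_cartesian([torch.tensor([[150., 92., 98., 1.]])]))
# [tensor([[150.,  92.,  98.]])]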

def points2boxes(points: Tensor) -> Tensor:
    """
    Convert points to boxes
    (given points of shape (200, 3), this reassembles them into (100, 6) boxes, i.e. (x1, y1, x2, y2, z1, z2) form)
    Args:
        points: boxes need to be order as specified
            order: [point_box_0, ... point_box_N/2] * 4
            format of points: (x, y(, z)) [N, dims]

    Returns:
        Tensor: bounding boxes [N / 2, dims * 2]
    """
    if points.nelement() > 0:
        points0, points1 = points.split(points.shape[0] // 2)
        boxes = torch.zeros(points.shape[0] // 2, points.shape[1] * 2).to(
            device=points.device, dtype=points.dtype)
        boxes[:, 0] = torch.min(points0[:, 0], points1[:, 0])
        boxes[:, 1] = torch.min(points0[:, 1], points1[:, 1])
        boxes[:, 2] = torch.max(points0[:, 0], points1[:, 0])
        boxes[:, 3] = torch.max(points0[:, 1], points1[:, 1])
        if boxes.shape[1] == 6:
            boxes[:, 4] = torch.min(points0[:, 2], points1[:, 2])
            boxes[:, 5] = torch.max(points0[:, 2], points1[:, 2])
        return boxes
    else:
        return torch.tensor([]).view(-1, points.shape[1] * 2).to(points)
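
A quick round-trip sketch with a made-up box: boxes2points followed by points2boxes reproduces the input, which is what makes the invert() call above work.

b = torch.tensor([[1., 2., 3., 4., 0., 5.]])
print(points2boxes(boxes2points(b)))
# tensor([[1., 2., 3., 4., 0., 5.]])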
