The following is a detailed walkthrough of the Mirror part of the nnDetection framework source code. Detailed analyses of the other parts are covered in my other posts; essentially every method that is used has its own dedicated write-up, so please use Ctrl + F to search for it.
The Mirror class mirrors data along the chosen spatial dimensions and applies the same mirroring to the points of the predicted pred_boxes, so that image and predictions stay consistent. The sections below walk through the mirroring pipeline in detail; the author's trick of mirroring all points at once via a single matrix multiplication is elegant and worth learning (a minimal sketch of the idea follows the class overview below).
class Mirror (path: nndet/io/transforms/spatial.py)
Mirror -> AbstractTransform -> torch.nn.Module
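Before reading the source, here is a minimal sketch of the core idea (not nndet code; shapes and values are made up): mirroring a coordinate x along a dimension of size S maps it to S - x, and with homogeneous coordinates this becomes one matrix multiplication applied to all points at once.

import torch

# Hypothetical example: 3 spatial dims, all of them mirrored, data shape (160, 112, 128).
shape = torch.tensor([160., 112., 128.])
mat = torch.eye(4)
mat[[0, 1, 2], [0, 1, 2]] = -1                 # reflection: x -> -x for every mirrored dim
mat[:-1, -1] = shape                           # offset: -x + shape == shape - x
points = torch.tensor([[10., 20., 30.],
                       [40., 50., 60.]])
hom = torch.cat([points, torch.ones(points.shape[0], 1)], dim=1)  # homogeneous points [N, 4]
mirrored = (hom @ mat.transpose(0, 1))[:, :-1]  # one matmul mirrors every point: equals shape - points
print(mirrored)  # tensor([[150.,  92.,  98.], [120.,  62.,  68.]])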
from typing import Any, List, Sequence

import torch
from torch import Tensor


class AbstractTransform(torch.nn.Module):
    def __init__(self, grad: bool = False, **kwargs):
        """
        Args:
            grad: enable gradient computation inside transformation
        """
        super().__init__()
        self.grad = grad

    def __call__(self, *args, **kwargs) -> Any:
        """
        Call super class with correct torch context

        Args:
            *args: forwarded positional arguments
            **kwargs: forwarded keyword arguments

        Returns:
            Any: transformed data
        """
        if self.grad:
            context = torch.enable_grad()
        else:
            context = torch.no_grad()

        with context:
            return super().__call__(*args, **kwargs)  # this ends up calling self.forward()
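A toy illustration of the context handling (my own example; _Scale is not part of nndet): the subclass's forward runs under torch.no_grad() unless grad=True is passed.

class _Scale(AbstractTransform):   # hypothetical subclass, only for demonstration
    def forward(self, x):
        return x * 2

x = torch.ones(3, requires_grad=True)
print(_Scale(grad=False)(x).requires_grad)  # False: forward ran inside torch.no_grad()
print(_Scale(grad=True)(x).requires_grad)   # True: gradients are tracked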
class Mirror(AbstractTransform):
    def __init__(self, keys: Sequence[str], dims: Sequence[int],
                 point_keys: Sequence[str] = (), box_keys: Sequence[str] = (),
                 grad: bool = False):
        """
        Mirror Transform

        Args:
            keys: keys to mirror (first key must correspond to data for
                shape information) expected shape [N, C, dims]
            dims: dimensions to mirror (starting from the first spatial
                dimension)
            point_keys: keys where points for transformation are located
                [N, dims]
            box_keys: keys where boxes are located; following format
                needs to be used (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
            grad: enable gradient computation inside transformation
        """
        super().__init__(grad=grad)
        self.dims = dims
        self.keys = keys
        self.point_keys = point_keys
        self.box_keys = box_keys

    def forward(self, **data) -> dict:
        """
        Implement transform functionality here

        Args:
            data: dict with data
                data comes in two forms:
                in transforms it is dict_keys(['data', 'tile_origin', 'crop']),
                in inverse_transforms it is dict_keys(['pred_boxes', 'pred_scores', 'pred_labels', 'pred_seg'])

        Returns:
            dict: dict with transformed data
        """
        for key in self.keys:
            data[key] = mirror(data[key], self.dims)

        data_shape = data[self.keys[0]].shape
        data_shapes = [tuple(data_shape[2:])] * data_shape[0]  # N copies of data_shape[2:]

        for key in self.box_keys:  # here there is only one key, 'pred_boxes'
            # data['pred_boxes'] is a list holding batch_size box tensors; each has shape (100, 6),
            # where 100 is model_detections_per_image from the config
            points = [boxes2points(b) for b in data[key]]  # rearrange the (x1, y1, x2, y2, z1, z2) boxes into [2N, dims] points (x, y, z), i.e. points[0].shape == (200, 3)
            points = mirror_points(points, self.dims, data_shapes)
            data[key] = [points2boxes(p) for p in points]

        for key in self.point_keys:
            data[key] = mirror_points(data[key], self.dims, data_shapes)
        return data

    def invert(self, **data) -> dict:
        """
        Revert mirroring

        Args:
            **data: dict with data

        Returns:
            dict with re-transformed data
        """
        return self(**data)  # calling this mirrors the data that forward flipped back to its original orientation
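A minimal usage sketch (all shapes and values invented for illustration): because flipping twice along the same dims is the identity, invert simply mirrors everything back.

t = Mirror(keys=("data",), dims=(0, 1, 2), box_keys=("pred_boxes",))
batch = {
    "data": torch.rand(2, 1, 160, 112, 128),  # [N, C, spatial dims]
    "pred_boxes": [torch.tensor([[10., 20., 30., 40., 50., 60.]]) for _ in range(2)],
}
out = t(**batch)                        # data flipped, boxes mirrored to match
restored = t.invert(**out)              # mirroring again restores the original
print(torch.allclose(restored["data"], batch["data"]))  # True
print(restored["pred_boxes"][0])        # tensor([[10., 20., 30., 40., 50., 60.]])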
def mirror(data: torch.Tensor, dims: Sequence[int]) -> torch.Tensor:
    """
    Mirror data at dims

    Args:
        data: input data [N, C, spatial dims]
        dims: dimensions to mirror starting from spatial dims
            e.g. dim=(0,) mirror the first spatial dimension

    Returns:
        torch.Tensor: tensor with mirrored dimensions
    """
    dims = [d + 2 for d in dims]  # input data is [N, C, spatial dims], so +2 maps the given dims onto the spatial dims of data
    return data.flip(dims)  # mirror (flip) the data along the given dims
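A quick check of the +2 offset (toy tensor): dims=(0,) flips only the first spatial dimension, i.e. tensor dim 2.

x = torch.arange(4).view(1, 1, 2, 2)   # [N=1, C=1, 2, 2]
print(mirror(x, dims=(0,)))            # tensor([[[[2, 3], [0, 1]]]]): rows swapped, columns untouched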
def boxes2points(boxes: Tensor) -> Tensor:
    """
    Convert boxes to points

    Args:
        boxes: (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]

    Returns:
        Tensor: points [N * 2, dims]
    """
    if boxes.shape[1] == 4:
        idx0 = [0, 1]
        idx1 = [2, 3]
    else:
        idx0 = [0, 1, 4]
        idx1 = [2, 3, 5]
    points0 = boxes[:, idx0]
    points1 = boxes[:, idx1]
    return torch.cat([points0, points1], dim=0)  # concatenate points0 and points1 along dim 0 to get a [2N, dims] tensor, one point per row
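Concretely (made-up values), a single 3D box (x1, y1, x2, y2, z1, z2) becomes its two corner points:

b = torch.tensor([[10., 20., 30., 40., 50., 60.]])
print(boxes2points(b))
# tensor([[10., 20., 50.],    <- corner (x1, y1, z1)
#         [30., 40., 60.]])   <- corner (x2, y2, z2)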
def mirror_points(points: Sequence[torch.Tensor], dims: Sequence[int],
                  data_shapes: Sequence[Sequence[int]]) -> List[torch.Tensor]:
    """
    Mirror points along given dimensions

    Args:
        points: points per batch element [N, dims]
        dims: dimensions to mirror
        data_shapes: shape of data

    Returns:
        List[Tensor]: transformed points [N, dims]
    """
    cartesian_dims = points[0].shape[1]
    homogeneous_points = points_to_homogeneous(points)  # append a column of ones to each (200, 3) point tensor, giving shape (200, 4)
    transformed = []
    for points_per_image, data_shape in zip(homogeneous_points, data_shapes):
        matrix = nd_mirror_matrix(cartesian_dims, dims, data_shape).to(points_per_image)  # see the nd_mirror_matrix analysis below; a (4, 4) matrix encoding which dims are mirrored and the data shape of those dims
        transformed.append(points_per_image.cpu() @ matrix.transpose(0, 1).cpu())  # mirror the points within the data: for every mirrored dim, each coordinate becomes data.shape(dim) - point(dim)
    """
    For example, if points_per_image has shape (200, 4), then after the @ (matrix multiplication)
    every point has been mapped to data.shape(dim) - point(dim) in each mirrored dim
    (here a is the (4, 4) mirror matrix):
    >>> a
    tensor([[ -1.,   0.,   0., 160.],
            [  0.,  -1.,   0., 112.],
            [  0.,   0.,  -1., 128.],
            [  0.,   0.,   0.,   1.]])
    >>> a.transpose(0, 1)
    tensor([[ -1.,   0.,   0.,   0.],
            [  0.,  -1.,   0.,   0.],
            [  0.,   0.,  -1.,   0.],
            [160., 112., 128.,   1.]])
    """
    return points_to_cartesian(transformed)
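A quick check with toy values: mirroring all three dims of a point inside data of shape (160, 112, 128) indeed yields shape - point per dimension.

pts = [torch.tensor([[10., 20., 30.]])]
print(mirror_points(pts, (0, 1, 2), [(160, 112, 128)]))  # [tensor([[150.,  92.,  98.]])]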
def points_to_homogeneous(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    """
    Transforms points from cartesian to homogeneous coordinates
    (appends a column of ones to each (200, 3) point tensor, giving shape (200, 4))

    Args:
        points: list of points to transform [N, dims] where N is the number
            of points and dims is the number of spatial dimensions

    Returns:
        torch.Tensor: the batch of points in homogeneous coordinates [N, dims + 1]
    """
    return [torch.cat([p, torch.ones(p.shape[0], 1).to(p)], dim=1) for p in points]
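For example (toy point):

print(points_to_homogeneous([torch.tensor([[10., 20., 30.]])]))  # [tensor([[10., 20., 30.,  1.]])]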
def nd_mirror_matrix(cartesian_dims: int, mirror_dims: Sequence[int],
                     data_shape: Sequence[int]) -> torch.Tensor:
    """
    Create n dimensional matrix for mirroring

    This builds an identity matrix (shape (4, 4) for 3d data), sets mat[dim][dim] to -1 for
    every dim that should be mirrored, and writes the data shape of each mirrored dim into
    the corresponding row of the last column, e.g.:

    mirror_dims : (1, 2)
    offset_mask : tensor([0., 1., 1.])  torch.Size([3])
    mat : tensor([[  1.,   0.,   0.,   0.],
                  [  0.,  -1.,   0., 112.],
                  [  0.,   0.,  -1., 128.],
                  [  0.,   0.,   0.,   1.]])  torch.Size([4, 4])

    mirror_dims : (0, 1, 2)
    offset_mask : tensor([1., 1., 1.])  torch.Size([3])
    mat : tensor([[ -1.,   0.,   0., 160.],
                  [  0.,  -1.,   0., 112.],
                  [  0.,   0.,  -1., 128.],
                  [  0.,   0.,   0.,   1.]])  torch.Size([4, 4])

    Args:
        cartesian_dims: number of cartesian dimensions
        mirror_dims: dimensions to mirror
        data_shape: shape of image

    Returns:
        Tensor: matrix for mirroring in homogeneous coordinates,
            [cartesian_dims + 1, cartesian_dims + 1]
    """
    mirror_dims = tuple(mirror_dims)
    data_shape = list(data_shape)
    homogeneous_dims = cartesian_dims + 1
    mat = torch.eye(homogeneous_dims, dtype=torch.float)

    # reflection
    mat[[mirror_dims] * 2] = -1  # set mat[dim][dim] to -1 for every mirrored dim

    # add data shape to axis which were reflected
    self_tensor = torch.zeros(cartesian_dims, dtype=torch.float)          # tensor([0., 0., 0.])
    index_tensor = torch.Tensor(mirror_dims).long()                       # tensor of the dims to be mirrored
    src_tensor = torch.tensor([1] * len(mirror_dims), dtype=torch.float)  # tensor([1.] * number of mirrored dims)
    offset_mask = self_tensor.scatter_(0, index_tensor, src_tensor)       # 1 at every mirrored dim, e.g. dims=(0,): tensor([1., 0., 0.]); dims=(1, 2): tensor([0., 1., 1.])
    mat[:-1, -1] = offset_mask * torch.tensor(data_shape)                 # as in the matrices above, write data_shape[dim] into the last column for every mirrored dim
    return mat
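Applying the example matrix from the docstring above to a homogeneous point (toy values) shows that only the mirrored dims change:

mat = nd_mirror_matrix(3, (1, 2), (160, 112, 128))
p = torch.tensor([[10., 20., 30., 1.]])   # one homogeneous point
print(p @ mat.transpose(0, 1))            # tensor([[10., 92., 98.,  1.]]): dim 0 untouched, dims 1 and 2 mirrored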
def points_to_cartesian(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    """
    Transforms points in homogeneous coordinates back to cartesian
    coordinates.

    Args:
        points: homogeneous points [N, in_dims], N number of points,
            in_dims number of input dimensions (spatial dimensions + 1)

    Returns:
        List[Tensor]: cartesian points [N, in_dims - 1] = [N, dims]
    """
    return [p[..., :-1] / p[..., -1][:, None] for p in points]  # points.shape == (200, 4); divide the first 3 columns by the last to get (200, 3) points. The division feels unnecessary here: p[..., :-1] alone would do, since the last column is always 1 and is never changed.
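For example, with a homogeneous scale other than 1 (toy value) the division matters; inside Mirror the last column is always 1, so it is effectively a no-op:

print(points_to_cartesian([torch.tensor([[20., 40., 60., 2.]])]))  # [tensor([[10., 20., 30.]])]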
def points2boxes(points: Tensor) -> Tensor:
    """
    Convert points to boxes

    With points of shape (200, 3), this reassembles the 200 points back into
    (100, 6) boxes in (x1, y1, x2, y2, z1, z2) format.

    Args:
        points: boxes need to be order as specified
            order: [point_box_0, ... point_box_N/2] * 2
            format of points: (x, y(, z)) [N, dims]

    Returns:
        Tensor: bounding boxes [N / 2, dims * 2]
    """
    if points.nelement() > 0:
        points0, points1 = points.split(points.shape[0] // 2)
        boxes = torch.zeros(points.shape[0] // 2, points.shape[1] * 2).to(
            device=points.device, dtype=points.dtype)
        boxes[:, 0] = torch.min(points0[:, 0], points1[:, 0])
        boxes[:, 1] = torch.min(points0[:, 1], points1[:, 1])
        boxes[:, 2] = torch.max(points0[:, 0], points1[:, 0])
        boxes[:, 3] = torch.max(points0[:, 1], points1[:, 1])
        if boxes.shape[1] == 6:
            boxes[:, 4] = torch.min(points0[:, 2], points1[:, 2])
            boxes[:, 5] = torch.max(points0[:, 2], points1[:, 2])
        return boxes
    else:
        return torch.tensor([]).view(-1, points.shape[1] * 2).to(points)
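Finally, a round trip (toy box): boxes2points followed by points2boxes reproduces the box, and the min/max calls keep (x1, y1, z1) <= (x2, y2, z2) even after mirroring has swapped the corners.

b = torch.tensor([[10., 20., 30., 40., 50., 60.]])
print(points2boxes(boxes2points(b)))  # tensor([[10., 20., 30., 40., 50., 60.]])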