PointPillars模型的损失沿袭了SECOND的做法:除了类别分类损失和box回归损失之外,还包含角度损失和方向损失。
PointPillars中角度的编码使用真实值和Anchor的残差。但是在使用SmoothL1计算具体损失时,会先计算出残差的sin值,再使用sin值来计算损失。
class ResidualCoder(object):
    """Encode 3D boxes as regression residuals relative to anchors.

    Attributes:
        code_size: Number of values per encoded box. It is the ``code_size``
            passed in, plus one when ``encode_angle_by_sincos`` is enabled
            (the heading then occupies two slots instead of one).
        encode_angle_by_sincos: When True the heading residual is encoded as
            a (cos, sin) difference pair instead of a raw angle difference.
    """

    def __init__(self, code_size=7, encode_angle_by_sincos=False, **kwargs):
        """Store the coder configuration; extra ``kwargs`` are ignored."""
        super().__init__()
        self.code_size = code_size
        self.encode_angle_by_sincos = encode_angle_by_sincos
        if self.encode_angle_by_sincos:
            self.code_size += 1

    def encode_torch(self, boxes, anchors):
        """Encode ground-truth boxes as residuals with respect to anchors.

        Neither input tensor is modified: the size clamping is applied to
        the split columns out of place.

        Args:
            boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
            anchors: (N, 7 + C) [x, y, z, dx, dy, dz, heading or *[cos, sin], ...]

        Returns:
            Tensor of encoded residuals, concatenated along the last
            dimension as [xt, yt, zt, dxt, dyt, dzt, <angle term(s)>, <extras>].
        """
        xa, ya, za, dxa, dya, dza, ra, *cas = torch.split(anchors, 1, dim=-1)
        xg, yg, zg, dxg, dyg, dzg, rg, *cgs = torch.split(boxes, 1, dim=-1)

        # Keep sizes strictly positive so the divisions and logs below stay
        # finite. torch.clamp_min returns new tensors, so the caller's
        # `boxes` and `anchors` are left untouched.
        dxa, dya, dza = (torch.clamp_min(s, min=1e-5) for s in (dxa, dya, dza))
        dxg, dyg, dzg = (torch.clamp_min(s, min=1e-5) for s in (dxg, dyg, dzg))

        # Centre offsets: x/y are normalised by the anchor's BEV diagonal,
        # z by the anchor height.
        diagonal = torch.sqrt(dxa ** 2 + dya ** 2)
        xt = (xg - xa) / diagonal
        yt = (yg - ya) / diagonal
        zt = (zg - za) / dza

        # Sizes are encoded as log ratios.
        dxt = torch.log(dxg / dxa)
        dyt = torch.log(dyg / dya)
        dzt = torch.log(dzg / dza)

        if self.encode_angle_by_sincos:
            rt_cos = torch.cos(rg) - torch.cos(ra)
            rt_sin = torch.sin(rg) - torch.sin(ra)
            rts = [rt_cos, rt_sin]
        else:
            # The angle is encoded here: plain heading residual.
            rts = [rg - ra]

        # Any extra channels are encoded as plain differences.
        cts = [g - a for g, a in zip(cgs, cas)]
        return torch.cat([xt, yt, zt, dxt, dyt, dzt, *rts, *cts], dim=-1)
因为 sin(a - b) = sin(a)cos(b) - cos(a)sin(b),这里对其做了拆分:预测值的角度项替换为 sin(a)cos(b),目标值的角度项替换为 cos(a)sin(b),两者相减即为 sin(a - b)。
# Rewrite the angle channel of the predictions and of the targets so that
# their element-wise difference equals sin(pred_angle - target_angle).
box_preds_sin, reg_targets_sin = self.add_sin_difference(box_preds, box_reg_targets)
@staticmethod
def add_sin_difference(boxes1, boxes2, dim=6):
    """Replace the angle channel so that a later subtraction yields sin(a - b).

    With a = boxes1[..., dim] and b = boxes2[..., dim], channel ``dim`` of
    the first output holds sin(a) * cos(b) and channel ``dim`` of the second
    output holds cos(a) * sin(b); their difference is sin(a - b).
    All other channels are passed through unchanged.

    Args:
        boxes1: Tensor whose last dimension contains the angle at ``dim``.
        boxes2: Tensor with the same layout as ``boxes1``.
        dim: Index of the angle channel in the last dimension; must not be -1.

    Returns:
        Tuple ``(boxes1, boxes2)`` of new tensors with the angle channel
        replaced as described above.
    """
    assert dim != -1
    angle1 = boxes1[..., dim:dim + 1]
    angle2 = boxes2[..., dim:dim + 1]

    sin_cos = torch.sin(angle1) * torch.cos(angle2)
    cos_sin = torch.cos(angle1) * torch.sin(angle2)

    encoded1 = torch.cat([boxes1[..., :dim], sin_cos, boxes1[..., dim + 1:]], dim=-1)
    encoded2 = torch.cat([boxes2[..., :dim], cos_sin, boxes2[..., dim + 1:]], dim=-1)
    return encoded1, encoded2
@staticmethod
def smooth_l1_loss(diff, beta):
    """Element-wise smooth L1 loss of ``diff`` with transition point ``beta``.

    For |diff| < beta the loss is 0.5 * diff**2 / beta, otherwise it is
    |diff| - 0.5 * beta. When ``beta`` is below 1e-5 the loss is plain
    |diff|, which avoids dividing by a near-zero ``beta``.

    Args:
        diff: Tensor of differences.
        beta: Scalar threshold between the quadratic and linear regions.

    Returns:
        Tensor of the same shape as ``diff``; no reduction is applied.
    """
    abs_diff = torch.abs(diff)
    if beta < 1e-5:
        return abs_diff

    quadratic = 0.5 * abs_diff ** 2 / beta
    linear = abs_diff - 0.5 * beta
    return torch.where(abs_diff < beta, quadratic, linear)
应该注意到,以上角度损失的计算无法区分预测方向和实际方向完全反向的情况:当两者相差k*pi时,sin(k*pi)与sin(0)一样,值都为0。所以,作者又补充了一个方向上的损失。
定义了正反两个方向(num_bins=2),通过limit_period函数将角度限定在[0, 2*pi)这个区间。
将位于[0, pi)的朝向角转化为0,位于[pi, 2*pi)的朝向角转化为1,并使用one-hot编码表示。
@staticmethod
def get_direction_target(anchors, reg_targets, one_hot=True, dir_offset=0, num_bins=2):
    """Build direction-bin classification targets from regression targets.

    The heading is reconstructed as ``reg_targets[..., 6] + anchors[..., 6]``,
    shifted by ``dir_offset``, wrapped with ``common_utils.limit_period``
    using offset 0 and period 2*pi, and then quantised into ``num_bins``
    equal-width bins.

    Args:
        anchors: Anchor tensor; it is viewed as (batch, -1, last_dim), where
            batch is ``reg_targets.shape[0]``. Channel 6 is read as the
            anchor heading.
        reg_targets: Regression targets; channel 6 is read as the heading
            residual.
        one_hot: When True, return one-hot vectors instead of bin indices.
        dir_offset: Angle subtracted from the heading before wrapping.
        num_bins: Number of direction bins.

    Returns:
        Integer (long) bin indices when ``one_hot`` is False; otherwise a
        tensor with a trailing dimension of size ``num_bins`` and the dtype
        of ``anchors`` holding the one-hot encoding.
    """
    num_batches = reg_targets.shape[0]
    anchors = anchors.view(num_batches, -1, anchors.shape[-1])

    heading = reg_targets[..., 6] + anchors[..., 6]
    wrapped = common_utils.limit_period(heading - dir_offset, 0, 2 * np.pi)

    bin_index = torch.floor(wrapped / (2 * np.pi / num_bins)).long()
    # Guard against an index falling outside the valid bins.
    bin_index = torch.clamp(bin_index, min=0, max=num_bins - 1)

    if not one_hot:
        return bin_index

    encoded = torch.zeros(
        *list(bin_index.shape), num_bins,
        dtype=anchors.dtype, device=bin_index.device,
    )
    encoded.scatter_(-1, bin_index.unsqueeze(dim=-1).long(), 1.0)
    return encoded
def limit_period(val, offset=0.5, period=np.pi):
    """Wrap ``val`` into the interval [-offset * period, (1 - offset) * period).

    Args:
        val: Values to wrap. The input is passed through
            ``check_numpy_to_torch``, which returns a tensor plus a flag
            (presumably marking whether the input was a NumPy array -
            confirm against that helper).
        offset: Fraction of a period by which the output interval is shifted
            below zero.
        period: Length of the wrapping interval.

    Returns:
        The wrapped values: converted with ``.numpy()`` when the flag from
        ``check_numpy_to_torch`` is truthy, otherwise a torch tensor.
    """
    tensor, was_numpy = check_numpy_to_torch(val)
    whole_periods = torch.floor(tensor / period + offset)
    wrapped = tensor - whole_periods * period
    if was_numpy:
        return wrapped.numpy()
    return wrapped
使用交叉熵损失来计算方向损失。
class WeightedCrossEntropyLoss(nn.Module):
    """
    Adapt inputs to the layout expected by PyTorch's official cross entropy
    loss and apply anchor-wise weighting.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor):
        """
        Args:
            input: (B, #anchors, #classes) float tensor.
                Predicted logits for each class.
            target: (B, #anchors, #classes) float tensor.
                One-hot classification targets.
            weights: (B, #anchors) float tensor.
                Anchor-wise weights.

        Returns:
            loss: (B, #anchors) float tensor.
                Weighted cross entropy loss without reduction.
        """
        # F.cross_entropy expects the class dimension at index 1.
        class_first_logits = input.permute(0, 2, 1)
        # Collapse the one-hot targets to class indices.
        class_index = target.argmax(dim=-1)
        per_anchor_loss = F.cross_entropy(class_first_logits, class_index, reduction='none')
        return per_anchor_loss * weights