class RoIHeads(torch.nn.Module):
    def forward(self,
                features,      # type: Dict[str, Tensor]
                proposals,     # type: List[Tensor]
                image_shapes,  # type: List[Tuple[int, int]]
                targets=None   # type: Optional[List[Dict[str, Tensor]]]
                ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Run the box branch: pool proposals, predict classes/boxes, then either
        compute losses (training) or build per-image detections (eval).

        Arguments:
            features: feature maps handed to ``self.box_roi_pool``.
            proposals: per-image proposal boxes. In training mode they are
                replaced by the samples returned from
                ``self.select_training_samples``.
            image_shapes: per-image sizes handed to RoI pooling and to
                ``self.postprocess_detections``.
            targets: ground truth; only read in training mode.

        Returns:
            (result, losses). In training mode ``result`` is an empty list and
            ``losses`` has the keys ``"loss_classifier"`` and ``"loss_box_reg"``.
            In eval mode ``losses`` is an empty dict and ``result`` holds one dict
            per image with the keys ``"boxes"``, ``"labels"`` and ``"scores"``.
        """
        if not self.training:
            labels = None
            regression_targets = None
        else:
            # Split proposals into positive/negative samples and collect the
            # matched ground-truth labels and box-regression targets.
            proposals, labels, regression_targets = self.select_training_samples(proposals, targets)

        # Multi-scale RoIAlign pooling of the (sampled) proposals.
        # Expected shape: [num_proposals, channels, height, width].
        pooled = self.box_roi_pool(features, proposals, image_shapes)
        # Fully-connected head applied after RoI pooling.
        # Expected shape: [num_proposals, representation_size].
        box_features = self.box_head(pooled)
        # Predict class scores and box-regression parameters.
        class_logits, box_regression = self.box_predictor(box_features)

        result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
        losses = {}
        if self.training:
            # labels/regression_targets were bound by select_training_samples
            # above; the assert presumably lets TorchScript refine their
            # Optional types.
            assert labels is not None and regression_targets is not None
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets)
            losses = {
                "loss_classifier": loss_classifier,
                "loss_box_reg": loss_box_reg,
            }
        else:
            det_boxes, det_scores, det_labels = self.postprocess_detections(
                class_logits, box_regression, proposals, image_shapes)
            for idx, per_image_boxes in enumerate(det_boxes):
                result.append({
                    "boxes": per_image_boxes,
                    "labels": det_labels[idx],
                    "scores": det_scores[idx],
                })
        return result, losses
# Function fastrcnn_loss
# (defined in roi_head.py)
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Compute the Faster R-CNN box-head losses.

    Arguments:
        class_logits (Tensor): predicted class scores, unpacked below as
            [num_proposals, num_classes].
        box_regression (Tensor): predicted box-regression parameters, reshaped
            below to [num_proposals, -1, 4] (one group of 4 per class).
        labels (List[Tensor]): per-image ground-truth class labels of the
            sampled proposals; concatenated along dim 0.
        regression_targets (List[Tensor]): per-image ground-truth regression
            targets of the sampled proposals; concatenated along dim 0.
            Presumably [num_proposals_i, 4] each -- TODO confirm.

    Returns:
        classification_loss (Tensor): cross-entropy over all sampled proposals.
        box_loss (Tensor): summed smooth-L1 loss over proposals with label > 0,
            divided by the total number of sampled proposals.
    """
    all_labels = torch.cat(labels, dim=0)
    all_targets = torch.cat(regression_targets, dim=0)

    # Classification loss over every sampled proposal (positive and negative).
    classification_loss = F.cross_entropy(class_logits, all_labels)

    # Indices of proposals whose label is > 0 (label 0 is presumably the
    # background class); only these contribute to box regression.
    pos_inds = torch.where(all_labels > 0)[0]
    pos_classes = all_labels[pos_inds]

    num_proposals, _ = class_logits.shape
    per_class_deltas = box_regression.reshape(num_proposals, -1, 4)

    # Advanced indexing picks, for each positive proposal, the 4 regression
    # values predicted for its own ground-truth class.
    summed_box_loss = det_utils.smooth_l1_loss(
        per_class_deltas[pos_inds, pos_classes],
        all_targets[pos_inds],
        beta=1 / 9,
        size_average=False,
    )
    # Normalise by all sampled proposals, not only the positive ones.
    box_loss = summed_box_loss / all_labels.numel()
    return classification_loss, box_loss
def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True) -> torch.Tensor:
    """
    Element-wise smooth L1 loss with a configurable transition point ``beta``.

    For ``n = |input - target|`` each element is
    ``0.5 * n ** 2 / beta`` when ``n < beta`` and ``n - 0.5 * beta`` otherwise.

    When ``beta < 1e-5`` (including ``beta == 0``) plain L1 (``n``) is used
    instead. The quadratic zone is negligibly small there. More importantly,
    with ``beta == 0`` the unselected ``torch.where`` branch ``0.5 * n ** 2 / beta``
    still receives a zero gradient and backpropagates ``0 / 0``, which makes
    every input gradient NaN.

    Arguments:
        input: predictions.
        target: values to compare against; must broadcast with ``input``.
        beta: transition point between the quadratic and linear regions.
        size_average: if True return the mean over all elements, else the sum.

    Returns:
        A scalar tensor (mean or sum of the element-wise loss).
    """
    n = torch.abs(input - target)
    if beta < 1e-5:
        # Pure L1; avoids NaN gradients from dividing by a (near-)zero beta.
        loss = n
    else:
        cond = torch.lt(n, beta)
        loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    if size_average:
        return loss.mean()
    return loss.sum()