The config used is CenterNet2_R50_1x.yaml.
We walk through inference first, then training. The overall code structure is:
1. Use ResNet-50 to generate five levels of feature maps
features = self.backbone(images.tensor) # downsampling factors (8, 16, 32, 64, 128)
# For an input of (1, 3, 768, 1344), the first level is (1, 256, 96, 168), ..., the last (1, 256, 6, 11)
2. Generate proposals
proposals, _ = self.proposal_generator(images, features, None)
3. roi_heads produces the results
results, _ = self.roi_heads(images, features, proposals, None)
The proposal generator breaks down as follows:
1. The first classification and regression pass yields two outputs per level: reg and agn_hm
----------------------------------------------------------------------------
clss_per_level, reg_pred_per_level, agn_hm_pred_per_level = self.centernet_head(features)
----------------------------------------------------------------------------
## reg_pred_per_level: (1, 4, 96, 168) ... (1, 4, 6, 11)
## clss_per_level is NoneType at test time; agn_hm_pred_per_level: (1, 1, 96, 168) ...
2. Compute absolute coordinates at every grid point of each level
----------------------------------------------------------------------------
grids = self.compute_grids(features) # (16128, 2), (4032, 2), ..., (66, 2)
----------------------------------------------------------------------------
## The 4 regressed parameters are only offsets; to recover boxes we need the absolute coordinates of every grid point:
h, w = feature.size()[-2:] # 96, 168
shifts_x = torch.arange(0, w * self.strides[level], step=self.strides[level], dtype=torch.float32, device=feature.device)
# 0, 8, 16, 24, 32, ..., 1336
shifts_y = torch.arange(0, h * self.strides[level], step=self.strides[level], dtype=torch.float32, device=feature.device)
# 0, 8, 16, 24, 32, ..., 760
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
grids_per_level = torch.stack((shift_x, shift_y), dim=1) + self.strides[level] // 2
# (96*168, 2) = (16128, 2); roughly 4, 12, 20, ..., 1340 — the center of each grid cell
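As a sanity check, the same grid construction can be run standalone on a toy 2x3 feature map with stride 8 (sizes here are made up for illustration):

import torch

h, w, stride = 2, 3, 8
shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32)  # [0, 8, 16]
shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32)  # [0, 8]
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)  # both (2, 3)
grids = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1) + stride // 2
print(grids)  # [[4, 4], [12, 4], [20, 4], [4, 12], [12, 12], [20, 12]]: grid-cell centers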
3. Record each level's spatial size
----------------------------------------------------------------------------
shapes_per_level = grids[0].new_tensor([(x.shape[2], x.shape[3]) for x in reg_pred_per_level])
----------------------------------------------------------------------------
# (96, 168), (48, 84), ..., (6, 11)
4. Threshold the heatmap and keep the top 1000 proposals per level
self.inference(images, clss_per_level, reg_pred_per_level, agn_hm_pred_per_level, grids)
i.e. proposals = self.predict_instances(grids, agn_hm_pred_per_level, reg_pred_per_level, images.image_sizes, [None for _ in agn_hm_pred_per_level])
----------------------------------------------------------------------------
i.e. self.predict_single_level(grids[l], logits_pred[l], reg_pred[l] * self.strides[l], image_sizes, agn_hm_pred[l], l, is_proposal=is_proposal)
----------------------------------------------------------------------------
# Treat each level's logits_pred as a heatmap: threshold at 0.001, keep the top 1000 candidates, and add the regressed offsets to the grid coordinates to get each level's boxlist:
boxlist.scores = torch.sqrt(per_box_cls) # (1000)
boxlist.pred_boxes = Boxes(detections) # (1000, 4)
boxlist.pred_classes = per_class # 1000 entries of 0 (class-agnostic)
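Per level, the decoding is roughly the following simplified sketch (the names mirror the real ones but the body is condensed; the 0.001 / 1000 thresholds are the ones mentioned above):

import torch

def decode_level(grids, heatmap, reg_pred, pre_nms_topk=1000, thresh=0.001):
    # grids: (HW, 2) absolute grid centers; heatmap: (HW,) scores;
    # reg_pred: (HW, 4) l/t/r/b offsets, already multiplied by the stride.
    keep = heatmap > thresh
    scores, reg, locs = heatmap[keep], reg_pred[keep], grids[keep]
    if scores.numel() > pre_nms_topk:
        scores, idx = scores.topk(pre_nms_topk)
        reg, locs = reg[idx], locs[idx]
    # recover absolute corners: (x - l, y - t, x + r, y + b)
    detections = torch.stack([locs[:, 0] - reg[:, 0], locs[:, 1] - reg[:, 1],
                              locs[:, 0] + reg[:, 2], locs[:, 1] + reg[:, 3]], dim=1)
    return detections, torch.sqrt(scores)  # sqrt, matching boxlist.scores above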
5. Run NMS over the results of all 5 levels
----------------------------------------------------------------------------
boxlists = self.nms_and_topK(boxlists)
----------------------------------------------------------------------------
The overall roi_heads code is below; the boxes pass through three cascade stages:
for k in range(self.num_cascade_stages):
if k > 0:
proposals = self._create_proposals_from_boxes(prev_pred_boxes, image_sizes)
if self.training:
proposals = self._match_and_label_boxes(proposals, k, targets)
predictions = self._run_stage(features, proposals, k) # tuple: (256, 81), (256, 4); the 4 channels are xywh-style deltas
prev_pred_boxes = self.box_predictor[k].predict_boxes(predictions, proposals)
head_outputs.append((self.box_predictor[k], predictions, proposals))
The loop runs 3 times; each iteration combines the features with the current proposals to produce new proposals, all collected in head_outputs.
The individual functions are expanded below.
_run_stage is mainly RoI pooling followed by classification and regression:
box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals]) # ([256, 256, 7, 7])
box_features = self.box_head[stage](box_features) # ([256, 1024])
return self.box_predictor[stage](box_features) # fully connected: (1024, 81) and (1024, 4)
Each of the 3 stages: Linear(1024, 81, bias=True) and Linear(1024, 4, bias=True)
_, proposal_deltas = predictions # (256, 4)
num_prop_per_image = [len(p) for p in proposals] # [256]
proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # ([256, 4])
predict_boxes = self.box2box_transform.apply_deltas(proposal_deltas, proposal_boxes) # decode once more
The RoI head's regression deltas decode the proposals again, as follows:
def apply_deltas(self, deltas, boxes): # input and output are tensors of the same shape, (256, 4)
deltas = deltas.float() # ensure fp32 for decoding precision
boxes = boxes.to(deltas.dtype)
widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1]
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx, wy, ww, wh = self.weights # (10.0, 10.0, 5.0, 5.0)
dx = deltas[:, 0::4] / wx
dy = deltas[:, 1::4] / wy
dw = deltas[:, 2::4] / ww
dh = deltas[:, 3::4] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=self.scale_clamp)
dh = torch.clamp(dh, max=self.scale_clamp) # 4.135166556742356 = log(1000 / 16)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
x1 = pred_ctr_x - 0.5 * pred_w
y1 = pred_ctr_y - 0.5 * pred_h
x2 = pred_ctr_x + 0.5 * pred_w
y2 = pred_ctr_y + 0.5 * pred_h
pred_boxes = torch.stack((x1, y1, x2, y2), dim=-1)
return pred_boxes.reshape(deltas.shape)
At test time this step has little effect here: the input equals the output.
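As a sanity check of the decoding math, zero deltas return the input boxes unchanged (toy numbers; weights (10, 10, 5, 5) as above):

import torch

boxes = torch.tensor([[10., 20., 50., 80.]])
deltas = torch.zeros(1, 4)  # dx = dy = dw = dh = 0
widths, heights = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
ctr_x, ctr_y = boxes[:, 0] + 0.5 * widths, boxes[:, 1] + 0.5 * heights
pred_w = torch.exp(deltas[:, 2] / 5.0) * widths    # exp(0) = 1, width unchanged
pred_h = torch.exp(deltas[:, 3] / 5.0) * heights
pred_ctr_x = deltas[:, 0] / 10.0 * widths + ctr_x  # zero offset, center unchanged
pred_ctr_y = deltas[:, 1] / 10.0 * heights + ctr_y
print(pred_ctr_x - 0.5 * pred_w, pred_ctr_y - 0.5 * pred_h)  # tensor([10.]) tensor([20.])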
Post-processing then takes three steps:
1. Average the scores of the three cascade stages ([256, 81])
scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] # turn each stage's predicted scores into probabilities, giving 3 tensors of (256, 81)
scores = [sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
for scores_per_image in zip(*scores_per_stage)]
2. Multiply by the first-stage (proposal) scores
scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)] # (256, 81) times (256, 1) gives (256, 81)
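Numerically this fusion is a geometric mean of the two scores; a toy example:

import torch

cascade_scores = torch.tensor([[0.9], [0.4]])  # stage-averaged scores, toy (2, 1)
proposal_scores = torch.tensor([0.64, 0.25])   # first-stage heatmap scores, (2,)
fused = (cascade_scores * proposal_scores[:, None]) ** 0.5
print(fused)  # [[0.7589], [0.3162]] = sqrt(0.9 * 0.64), sqrt(0.4 * 0.25)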
3. Decode the last cascade stage's output into the final boxes, then post-process:
predictor, predictions, proposals = head_outputs[-1]
boxes = predictor.predict_boxes(predictions, proposals) # ([256, 4])
pred_instances, _ = fast_rcnn_inference(boxes,scores, image_sizes,
predictor.test_score_thresh,
predictor.test_nms_thresh,
predictor.test_topk_per_image,) # 0.3 0.7 100
cls_scores = result.scores
image_thresh, _ = torch.kthvalue(cls_scores.cpu(),num_dets - post_nms_topk + 1)
# e.g. num_dets = 2492 detections in cls_scores, but only the top post_nms_topk = 256 scores are needed; kthvalue yields the cutoff image_thresh
keep = cls_scores >= image_thresh.item()
keep = torch.nonzero(keep).squeeze(1)
result = result[keep]
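The kthvalue trick is easy to verify on a toy tensor (numbers made up):

import torch

scores = torch.tensor([0.9, 0.1, 0.5, 0.7, 0.3])
post_nms_topk = 2
# the (num_dets - post_nms_topk + 1)-th smallest value is the cutoff
thresh, _ = torch.kthvalue(scores, scores.numel() - post_nms_topk + 1)
print(thresh)            # tensor(0.7000)
print(scores >= thresh)  # [True, False, False, True, False]: exactly the top 2 survive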
from torchvision.ops import boxes as box_ops
keep = box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
# keep (2492): tensor([2645, 249, 1724, ..., 2081, 2999, 3062], device='cuda:0')
# boxes is a (3318, 4) tensor, scores (3318), idxs the class ids (3318 zeros), iou_threshold = 0.9
boxlist = boxlist[keep]
What the pooler does: it takes an (n, 5) tensor of RoIs (regions of interest), assigns each one to one of three feature scales according to its box size (five scales would also work), then crops the corresponding patch out of the feature pyramid.
from torchvision.ops import RoIPool
self.level_poolers = nn.ModuleList(RoIPool(output_size, spatial_scale=scale) for scale in scales)
level_assignments = assign_boxes_to_levels( box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level)
# (256): [0, 2, 0, 0, 1, 0, 1, 0, 2, 1, 0, 0, 0, 2...]
for level, pooler in enumerate(self.level_poolers):
inds = nonzero_tuple(level_assignments == level)[0] # e.g. 179 indices at this level
pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
# Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level))
# where level_poolers is actually:
self.level_poolers = nn.ModuleList( ROIAlign( output_size, spatial_scale=scale, sampling_ratio=0, aligned=True ) for scale in scales )
## scale ranges from 1/8 to 1/128
from torchvision.ops import roi_align
class ROIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
super().__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
self.aligned = aligned
from torchvision import __version__
version = tuple(int(x) for x in __version__.split(".")[:2])
# https://github.com/pytorch/vision/pull/2438
assert version >= (0, 7), "Require torchvision >= 0.7"
def forward(self, input, rois):
"""
Args:
input: NCHW images
rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
"""
assert rois.dim() == 2 and rois.size(1) == 5
return roi_align(
input, # the feature map of the corresponding level
rois.to(dtype=input.dtype), # (n, 5); the first column is the image index into N, e.g. 3
self.output_size,
self.spatial_scale,
self.sampling_ratio, # usually 0
self.aligned, # usually True
)
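roi_align itself can be exercised directly; a minimal usage example on a toy feature map:

import torch
from torchvision.ops import roi_align

feat = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)  # one 8x8 map
rois = torch.tensor([[0., 1., 1., 5., 5.]])  # [batch_index, x1, y1, x2, y2]
out = roi_align(feat, rois, output_size=(2, 2), spatial_scale=1.0,
                sampling_ratio=0, aligned=True)
print(out.shape)  # torch.Size([1, 1, 2, 2]): the box pooled to a 2x2 grid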
For example: input (256, 4) boxes, output (256) level indices, i.e. [0, 0, 0, 1, 1, 0, 0, 0, 2, 2, ...]:
def assign_boxes_to_levels(
box_lists: List[Boxes],
min_level: int,
max_level: int,
canonical_box_size: int,
canonical_level: int,
):
box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists])) # sqrt of each box's area, e.g. for 2048 boxes
level_assignments = torch.floor(
canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
) # the log2 term lands roughly in (-6, 2) here; canonical_box_size = 224
# clamp level to (min, max), in case the box size is too large or too small
# for the available feature maps
level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
return level_assignments.to(torch.int64) - min_level
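A worked example of the assignment rule (assuming min_level=3 and max_level=5 for the three pooled scales; the actual config may differ):

import torch

# canonical: a box with sqrt(area) = 224 goes to level 4
box_sizes = torch.tensor([112., 224., 448.])  # sqrt of box areas
levels = torch.floor(4 + torch.log2(box_sizes / 224 + 1e-8))
levels = torch.clamp(levels, min=3, max=5).long() - 3
print(levels)  # tensor([0, 1, 2]): halving the size moves a box one level finer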
The training loss has two parts:
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
Mapping the GT labels onto the feature-map grid:
def _get_ground_truth(self, grids, shapes_per_level, gt_instances):
'''
Input:
grids: list of tensors [(hl x wl, 2)]_l
shapes_per_level: list of tuples L x 2:
gt_instances: gt instances
Return:
pos_inds: N
labels: N
reg_targets: M x 4
flattened_hms: M x C or M x 1
N: number of objects in all images
M: number of pixels from all FPN levels
'''
# get positive pixel index
if not self.more_pos:
pos_inds, labels = self._get_label_inds(
gt_instances, shapes_per_level) # N, N: indices of all object center points across the batch, e.g. [516, 3692, 7533, ..., 55, 209, ..., 71433]
else:
pos_inds, labels = None, None
heatmap_channels = self.num_classes
L = len(grids)
num_loc_list = [len(loc) for loc in grids]
strides = torch.cat([
shapes_per_level.new_ones(num_loc_list[l]) * self.strides[l] \
for l in range(L)]).float() # M = 19620: 14720*[8] + ... + 240*[64] + 60*[128]
reg_size_ranges = torch.cat([
shapes_per_level.new_tensor(self.sizes_of_interest[l]).float().view(
1, 2).expand(num_loc_list[l], 2) for l in range(L)]) # M x 2 (19620, 2): 14720*[0, 80] + ... + 240*[256, 640] + 60*[512, 100000]
grids = torch.cat(grids, dim=0) # M x 2 (19620, 2): (14720, 2), (3680, 2), (920, 2), ..., (60, 2)
M = grids.shape[0]
reg_targets = []
flattened_hms = []
for i in range(len(gt_instances)): # images
boxes = gt_instances[i].gt_boxes.tensor # N x 4
area = gt_instances[i].gt_boxes.area() # N
gt_classes = gt_instances[i].gt_classes # N in [0, self.num_classes]
N = boxes.shape[0]
if N == 0:
reg_targets.append(grids.new_zeros((M, 4)) - INF)
flattened_hms.append(
grids.new_zeros((
M, 1 if self.only_proposal else heatmap_channels)))
continue
l = grids[:, 0].view(M, 1) - boxes[:, 0].view(1, N) # M x N (19620, 75)
t = grids[:, 1].view(M, 1) - boxes[:, 1].view(1, N) # M x N
r = boxes[:, 2].view(1, N) - grids[:, 0].view(M, 1) # M x N
b = boxes[:, 3].view(1, N) - grids[:, 1].view(M, 1) # M x N
reg_target = torch.stack([l, t, r, b], dim=2) # M x N x 4
centers = ((boxes[:, [0, 1]] + boxes[:, [2, 3]]) / 2) # N x 2
centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
strides_expanded = strides.view(M, 1, 1).expand(M, N, 2)
centers_discret = ((centers_expanded / strides_expanded).int() * \
strides_expanded).float() + strides_expanded / 2 # M x N x 2: the grid point nearest to each object center
is_peak = (((grids.view(M, 1, 2).expand(M, N, 2) - \
centers_discret) ** 2).sum(dim=2) == 0) # M x N
is_in_boxes = reg_target.min(dim=2)[0] > 0 # M x N
is_center3x3 = self.get_center3x3(
grids, centers, strides) & is_in_boxes # input: (M, 2), (N, 2), (M) --> M x N
is_cared_in_the_level = self.assign_reg_fpn(
reg_target, reg_size_ranges) # M x N: compute each reg_target's (l, t, r, b) scale and compare it against size_ranges
reg_mask = is_center3x3 & is_cared_in_the_level # M x N
dist2 = ((grids.view(M, 1, 2).expand(M, N, 2) - \
centers_expanded) ** 2).sum(dim=2) # M x N
dist2[is_peak] = 0
radius2 = self.delta ** 2 * 2 * area # N
radius2 = torch.clamp(
radius2, min=self.min_radius ** 2)
weighted_dist2 = dist2 / radius2.view(1, N).expand(M, N) # M x N
reg_target = self._get_reg_targets(
reg_target, weighted_dist2.clone(), reg_mask, area) # M x 4
if self.only_proposal:
flattened_hm = self._create_agn_heatmaps_from_dist(
weighted_dist2.clone()) # M x 1: min(dist, dim=1) maps (M, N) to (M), i.e. for every location find the nearest gt and keep that distance
else: # not taken here
flattened_hm = self._create_heatmaps_from_dist(
weighted_dist2.clone(), gt_classes,
channels=heatmap_channels) # M x C
reg_targets.append(reg_target) # (M, 4)
flattened_hms.append(flattened_hm) # (M, 1)
# transpose im first training_targets to level first ones
reg_targets = _transpose(reg_targets, num_loc_list) # 5 levels: [64512, 4], [16128, 4], ..., [66, 4]
flattened_hms = _transpose(flattened_hms, num_loc_list) # 5 levels: [64512, 1], [16128, 1], ..., [66, 1]
for l in range(len(reg_targets)):
reg_targets[l] = reg_targets[l] / float(self.strides[l])
reg_targets = cat([x for x in reg_targets], dim=0) # MB x 4 (85944, 4): 64512 + 16128 + ... + 66
flattened_hms = cat([x for x in flattened_hms], dim=0) # MB x C (85944, 1)
return pos_inds, labels, reg_targets, flattened_hms
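The centers_discret step above is the crux of target assignment: each GT center is snapped to the nearest grid-cell center, which then becomes the single "peak" location (is_peak). A standalone check with toy numbers:

import torch

stride = 8.0
center = torch.tensor([101.0, 53.0])  # a GT box center in image coordinates
discrete = ((center / stride).int() * stride).float() + stride / 2
print(discrete)  # tensor([100., 52.]): matches one of the grid centers from compute_grids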
def _get_label_inds(self, gt_instances, shapes_per_level):
'''
Inputs:
gt_instances: [n_i], sum n_i = N
shapes_per_level: L x 2 [(h_l, w_l)]_L
Returns:
pos_inds: N'
labels: N'
'''
pos_inds = []
labels = []
L = len(self.strides) # 5
B = len(gt_instances) # bs
shapes_per_level = shapes_per_level.long()
loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long() # [16128, 4032, 1008, 252, 66]
level_bases = []
s = 0
for l in range(L):
level_bases.append(s)
s = s + B * loc_per_level[l] # [0, 64512, 80640, 84672, 85680 ]
level_bases = shapes_per_level.new_tensor(level_bases).long() # [0, 64512, 80640, 84672, 85680 ]
strides_default = shapes_per_level.new_tensor(self.strides).float() # [ 8, 16, 32, 64, 128 ]
for im_i in range(B):
targets_per_im = gt_instances[im_i]
bboxes = targets_per_im.gt_boxes.tensor # n x 4: (x1, y1, x2, y2)
n = bboxes.shape[0]
centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2) # n x 2
centers = centers.view(n, 1, 2).expand(n, L, 2) # ( n, 5, 2 )
strides = strides_default.view(1, L, 1).expand(n, L, 2) # [ 8., 16., 32., 64., 128 ] --> ( n, 5, 2 )
centers_inds = (centers / strides).long() # n x 5 x 2
Ws = shapes_per_level[:, 1].view(1, L).expand(n, L) # (n, 5): the widths of the 5 feature maps, pulled out on their own
pos_ind = level_bases.view(1, L).expand(n, L) + \
im_i * loc_per_level.view(1, L).expand(n, L) + \
centers_inds[:, :, 1] * Ws + \
centers_inds[:, :, 0] # n x 5: flatten the B images' 5 feature levels into one long vector and locate each GT center's index in it
is_cared_in_the_level = self.assign_fpn_level(bboxes) # boxes in absolute coords --> (n, 5) of [True, False, ...]: choose each object's level by its size
pos_ind = pos_ind[is_cared_in_the_level].view(-1) # (n)
label = targets_per_im.gt_classes.view(
n, 1).expand(n, L)[is_cared_in_the_level].view(-1) # (n) absolute class ids
pos_inds.append(pos_ind) # n'
labels.append(label) # n'
pos_inds = torch.cat(pos_inds, dim=0).long() # indices of all object center points across the batch, e.g. [516, 3692, 7533, ..., 55, 209, ..., 71433]
labels = torch.cat(labels, dim=0)
return pos_inds, labels # N, N
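The flattened index arithmetic is easiest to see with concrete numbers (toy example with batch size 2 and the level shapes from above):

# Say a GT center lands on level 1 (stride 16, map 48x84), at grid cell
# (cx, cy) = (10, 5), in image im_i = 1 of the batch.
level_base = 2 * 96 * 168      # all level-0 locations of both images come first: 32256
loc_per_level = 48 * 84        # locations per image on level 1: 4032
pos_ind = level_base + 1 * loc_per_level + 5 * 84 + 10
print(pos_ind)  # 36718: the center's offset in the flattened level-major layout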
def assign_fpn_level(self, boxes):
'''
Inputs:
boxes: n x 4
size_ranges: L x 2
Return:
is_cared_in_the_level: n x L
'''
size_ranges = boxes.new_tensor(
self.sizes_of_interest).view(len(self.sizes_of_interest), 2) # 5 x 2 :[0, 80], [64, 160], [128, 320], [256, 640], [512, 10000000]
crit = ((boxes[:, 2:] - boxes[:, :2]) ** 2).sum(dim=1) ** 0.5 / 2 # n: half the box diagonal (not width * height)
n, L = crit.shape[0], size_ranges.shape[0]
crit = crit.view(n, 1).expand(n, L)
size_ranges_expand = size_ranges.view(1, L, 2).expand(n, L, 2)
is_cared_in_the_level = (crit >= size_ranges_expand[:, :, 0]) & \
(crit <= size_ranges_expand[:, :, 1])
return is_cared_in_the_level # n x 5: [True, False, ...]
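Note that crit is half the box diagonal, not the area; a 96x128 box, for example, sits exactly on the boundary of two of the ranges above:

import torch

box = torch.tensor([[0., 0., 96., 128.]])
crit = ((box[:, 2:] - box[:, :2]) ** 2).sum(dim=1) ** 0.5 / 2
print(crit)  # tensor([80.]): inside both [0, 80] and [64, 160], so levels 0 and 1 both care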
def _get_reg_targets(self, reg_targets, dist, mask, area):
'''
reg_targets (M x N x 4): long tensor
dist (M x N)
is_*: M x N
'''
dist[mask == 0] = INF * 1.0
min_dist, min_inds = dist.min(dim=1) # M
reg_targets_per_im = reg_targets[
range(len(reg_targets)), min_inds] # M x N x 4 --> M x 4
reg_targets_per_im[min_dist == INF] = - INF
return reg_targets_per_im
def _create_agn_heatmaps_from_dist(self, dist):
'''
TODO (Xingyi): merge it with _create_heatmaps_from_dist
dist: M x N
return:
heatmaps: M x 1
'''
heatmaps = dist.new_zeros((dist.shape[0], 1)) # (M, 1)
heatmaps[:, 0] = torch.exp(-dist.min(dim=1)[0])
zeros = heatmaps < 1e-4
heatmaps[zeros] = 0
return heatmaps
def losses(
self, pos_inds, labels, reg_targets, flattened_hms,
logits_pred, reg_pred, agn_hm_pred):
'''
Inputs:
pos_inds: N
labels: N
reg_targets: M x 4
flattened_hms: M x C
logits_pred: M x C
reg_pred: M x 4
agn_hm_pred: M x 1 or None
N: number of positive locations in all images
M: number of pixels from all FPN levels
C: number of classes
'''
assert (torch.isfinite(reg_pred).all().item())
num_pos_local = pos_inds.numel()
num_gpus = get_world_size()
total_num_pos = reduce_sum(
pos_inds.new_tensor([num_pos_local])).item()
num_pos_avg = max(total_num_pos / num_gpus, 1.0)
losses = {}
if not self.only_proposal:
pos_loss, neg_loss = heatmap_focal_loss_jit(
logits_pred, flattened_hms, pos_inds, labels,
alpha=self.hm_focal_alpha,
beta=self.hm_focal_beta,
gamma=self.loss_gamma,
reduction='sum',
sigmoid_clamp=self.sigmoid_clamp,
ignore_high_fp=self.ignore_high_fp,
)
pos_loss = self.pos_weight * pos_loss / num_pos_avg
neg_loss = self.neg_weight * neg_loss / num_pos_avg
losses['loss_centernet_pos'] = pos_loss
losses['loss_centernet_neg'] = neg_loss
reg_inds = torch.nonzero(reg_targets.max(dim=1)[0] >= 0).squeeze(1) # select the positive samples, e.g. 832 locations (for about 200 gt boxes)
reg_pred = reg_pred[reg_inds]
reg_targets_pos = reg_targets[reg_inds] # (832, 4)
reg_weight_map = flattened_hms.max(dim=1)[0] # per-location heatmap peak, derived from each grid point's distance to the nearest center (M = 81840)
reg_weight_map = reg_weight_map[reg_inds] # (832)
reg_weight_map = reg_weight_map * 0 + 1 \
if self.not_norm_reg else reg_weight_map # (832) of all ones when not_norm_reg
reg_norm = max(reduce_sum(reg_weight_map.sum()).item() / num_gpus, 1)
reg_loss = self.reg_weight * self.iou_loss(
reg_pred, reg_targets_pos, reg_weight_map,
reduction='sum') / reg_norm
losses['loss_centernet_loc'] = reg_loss
if self.with_agn_hm: # True here
cat_agn_heatmap = flattened_hms.max(dim=1)[0] # M
agn_pos_loss, agn_neg_loss = binary_heatmap_focal_loss_jit(
agn_hm_pred, cat_agn_heatmap, pos_inds,
alpha=self.hm_focal_alpha,
beta=self.hm_focal_beta,
gamma=self.loss_gamma,
sigmoid_clamp=self.sigmoid_clamp,
ignore_high_fp=self.ignore_high_fp,
)
agn_pos_loss = self.pos_weight * agn_pos_loss / num_pos_avg
agn_neg_loss = self.neg_weight * agn_neg_loss / num_pos_avg
losses['loss_centernet_agn_pos'] = agn_pos_loss
losses['loss_centernet_agn_neg'] = agn_neg_loss
if self.debug:
print('losses', losses)
print('total_num_pos', total_num_pos)
return losses
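For reference, a minimal sketch of what binary_heatmap_focal_loss_jit computes, assuming it follows the standard CornerNet-style formulation (the real jitted implementation may differ in details such as the ignore_high_fp handling):

import torch

def binary_heatmap_focal_loss(pred_logits, gt, pos_inds, gamma=2.0, beta=4.0, eps=1e-4):
    # pred_logits, gt: (M,); pos_inds: long indices of GT center locations
    pred = torch.clamp(pred_logits.sigmoid(), min=eps, max=1 - eps)  # sigmoid_clamp
    pos_pred = pred[pos_inds]
    pos_loss = -torch.log(pos_pred) * (1 - pos_pred) ** gamma        # pull peaks up
    neg_weights = (1 - gt) ** beta                                   # soften near centers
    neg_loss = -torch.log(1 - pred) * pred ** gamma * neg_weights    # push the rest down
    neg_loss[pos_inds] = 0  # positive locations are handled by pos_loss only
    return pos_loss.sum(), neg_loss.sum()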