This is the overall network of SA-SSD, which is composed of the following parts.
Each part will be analyzed in detail later; first, let's look at the network as a whole (just which functions it has, with the function bodies omitted for now):
class SingleStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                          MaskTestMixin):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extra_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        else:
            raise NotImplementedError
        if bbox_head is not None:
            self.rpn_head = builder.build_single_stage_head(bbox_head)
        if extra_head is not None:
            self.extra_head = builder.build_single_stage_head(extra_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained)

    @property
    def with_rpn(self):
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    def merge_second_batch(self, batch_args):
        return ret

    def forward_train(self, img, img_meta, **kwargs):
        return losses

    def forward_test(self, img, img_meta, **kwargs):
        return results
Code analysis:
def __init__(self,
             backbone,
             neck=None,
             bbox_head=None,
             extra_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None):
    super(SingleStageDetector, self).__init__()
    # Initialize the backbone
    self.backbone = builder.build_backbone(backbone)
    # Initialize the neck
    if neck is not None:
        self.neck = builder.build_neck(neck)
    else:
        raise NotImplementedError
    # Initialize the head
    if bbox_head is not None:
        self.rpn_head = builder.build_single_stage_head(bbox_head)
    # Initialize the extra-head
    if extra_head is not None:
        self.extra_head = builder.build_single_stage_head(extra_head)
    # Store the parameters from the cfg
    self.train_cfg = train_cfg
    self.test_cfg = test_cfg
    # Initialize the weights
    self.init_weights(pretrained)
The initialization of each part works the same way: if you step into these builder functions, you will find that each component is built from the settings in the cfg file, and they all eventually land in this obj_from_dict function.
# Initializes an object of a class found under `parent` according to the dict `info`.
# Put simply, the dict stores the class's constructor arguments; the core call is getattr.
# In short, obj_from_dict is a utility function for building a specified object.
def obj_from_dict(info, parent=None, default_args=None):
    """Initialize an object from dict.

    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object types and arguments.
        parent (:class:`module`): Module which may contain expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object.

    Returns:
        any type: Object built from the dict.
    """
    # First, check that info is a dict and that it contains the key 'type';
    # default_args must also be a dict or None
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        if parent is not None:
            obj_type = getattr(parent, obj_type)
        else:
            obj_type = sys.modules[obj_type]
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but '
                        f'got {type(obj_type)}')
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)  # instantiate the class with the arguments in args
At first I couldn't make sense of this function, but on a closer look it simply locates the class to be initialized according to the cfg settings and then passes in the parameters from the cfg. For example:
neck=dict(
    type='SpMiddleFHD',
    output_shape=[40, 1600, 1408],
    num_input_features=4,
    num_hidden_features=64 * 5,
),
This is the neck configuration in the cfg file. First the class SpMiddleFHD is located via type='SpMiddleFHD', and then the class is instantiated with the remaining parameters from the cfg. At this point
return obj_type(**args)
is equivalent to:
return SpMiddleFHD(output_shape=[40, 1600, 1408], num_input_features=4, num_hidden_features=64 * 5)
OK, the initialization of the other parts follows the same pattern. The codebase is presumably built on top of mmdetection, and this is exactly how mmdetection does it. Once you understand it, you can use the same pattern in your own code; it is simple and clean.
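To make the pattern concrete, here is a minimal, self-contained toy version (the Dummy class and the config dict are hypothetical, purely for illustration):

import sys

class Dummy:
    def __init__(self, a, b=0):
        self.a, self.b = a, b

def obj_from_dict_sketch(info, parent=None, default_args=None):
    # Same idea as obj_from_dict: pop 'type', resolve the class, pass the rest
    args = dict(info)
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_type = getattr(parent, obj_type) if parent is not None else sys.modules[obj_type]
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)

cfg = dict(type='Dummy', a=1)
obj = obj_from_dict_sketch(cfg, parent=sys.modules[__name__])
print(type(obj).__name__, obj.a, obj.b)  # Dummy 1 0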
Now let's look at the forward-pass function; the annotations are inline in the code:
# img.shape [B, 3, 384, 1248]
# img_meta: dict
# img_meta[0]:
# img_shape : tuple (375, 1242, 3)
# sample_idx
# calib
# kwargs:
# 1. anchors list: len(anchors) = B
# 2. voxels list: len(voxels) = B
# 3. coordinates list: len(coordinates) = B
# 4. num_points list: len(num_points) = B
# 5. anchor_mask list: len(anchor_mask) = B
# 6. gt_labels list: len(gt_labels) = B
# 7. gt_bboxes list: len(gt_bboxes) = B
def forward_train(self, img, img_meta, **kwargs):
    # --------------------------------------------------------------------------
    # from mmdet.datasets.kitti_utils import draw_lidar
    # f = draw_lidar(kwargs["voxels"][0].cpu().numpy(), show=True)  # visualize the whole point cloud
    # --------------------------------------------------------------------------
    batch_size = len(img_meta)  # B

    ret = self.merge_second_batch(kwargs)

    # vx is computed from ret['voxels']
    vx = self.backbone(ret['voxels'], ret['num_points'])

    # x.shape = [2, 256, 200, 176]
    # conv6.shape = [2, 256, 200, 176]
    # point_misc : tuple of 3
    #   1. point_mean : shape [N, 4], [:, 0] is the batch index
    #   2. point_cls  : shape [N, 1]
    #   3. point_reg  : shape [N, 3]
    (x, conv6), point_misc = self.neck(vx, ret['coordinates'], batch_size)

    losses = dict()

    aux_loss = self.neck.aux_loss(*point_misc, gt_bboxes=ret['gt_bboxes'])
    losses.update(aux_loss)

    # RPN forward and loss
    if self.with_rpn:
        # rpn_outs : tuple of 3
        #   1. box_preds     : shape [N, 200, 176, 14]
        #   2. cls_preds     : shape [N, 200, 176, 2]
        #   3. dir_cls_preds : shape [N, 200, 176, 4]
        rpn_outs = self.rpn_head(x)

        # rpn_loss_inputs : tuple of 8
        rpn_loss_inputs = rpn_outs + (ret['gt_bboxes'], ret['gt_labels'], ret['anchors'], ret['anchors_mask'], self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
        losses.update(rpn_losses)

        # guided_anchors.shape :
        #     [num_of_guided_anchors, 7]
        #   + [num_of_gt_bboxes, 7]
        #   ----------------------------
        #   = [all_num, 7]
        guided_anchors = self.rpn_head.get_guided_anchors(*rpn_outs, ret['anchors'], ret['anchors_mask'], ret['gt_bboxes'], thr=0.1)
    else:
        raise NotImplementedError

    # bbox head forward and loss
    if self.extra_head:
        bbox_score = self.extra_head(conv6, guided_anchors)
        refine_loss_inputs = (bbox_score, ret['gt_bboxes'], ret['gt_labels'], guided_anchors, self.train_cfg.extra)
        refine_losses = self.extra_head.loss(*refine_loss_inputs)
        losses.update(refine_losses)

    return losses
First, the incoming keyword arguments go through the merge_second_batch() function; let's take a look:
def merge_second_batch(self, batch_args):
    ret = {}
    for key, elems in batch_args.items():
        if key in [
                'voxels', 'num_points',
        ]:
            ret[key] = torch.cat(elems, dim=0)
        elif key == 'coordinates':
            coors = []
            for i, coor in enumerate(elems):  # coor.shape : torch.Size([19480, 3])
                coor_pad = F.pad(
                    coor, [1, 0, 0, 0],
                    mode='constant',
                    value=i)  # for F.pad see https://blog.csdn.net/jorg_zhao/article/details/105295686
                coors.append(coor_pad)
            ret[key] = torch.cat(coors, dim=0)
        elif key in [
                'img_meta', 'gt_labels', 'gt_bboxes',
        ]:
            ret[key] = elems
        else:
            ret[key] = torch.stack(elems, dim=0)
    return ret
It mainly merges the batch according to each key, which is straightforward. Note this one step:
coor_pad = F.pad(
    coor, [1, 0, 0, 0],
    mode='constant',
    value=i)
coors.append(coor_pad)
For the usage of F.pad, see the link in the comment above. The purpose is to prepend one extra column to coordinates (filled with i = 0, 1, …) that records which sample of the batch each voxel belongs to, as the small demo below shows.
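A minimal standalone demo of what this padding does (with a tiny made-up coordinate tensor):

import torch
import torch.nn.functional as F

coor = torch.tensor([[10, 20, 30],
                     [11, 21, 31]])  # [num_voxels, 3]
# [1, 0, 0, 0] pads one column on the left of the last dim and nothing elsewhere
coor_pad = F.pad(coor, [1, 0, 0, 0], mode='constant', value=2)  # batch index i = 2
print(coor_pad)
# tensor([[ 2, 10, 20, 30],
#         [ 2, 11, 21, 31]])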
Next comes building the loss, which consists of three parts in total:
loss_all = aux_loss + rpn_loss + extra_head_loss
(aux_loss is the auxiliary-network loss from the code above.) The detailed composition of each loss term is analyzed later.
In the Auxiliary Network, foreground points and background points need to be segmented, so the first step is to generate foreground/background labels for the points:
def pts_in_boxes3d(pts, boxes3d):
    N = len(pts)
    M = len(boxes3d)
    pts_in_flag = torch.IntTensor(M, N).fill_(0)
    reg_target = torch.FloatTensor(N, 3).fill_(0)
    points_op_cpu.pts_in_boxes3d(pts.contiguous(), boxes3d.contiguous(), pts_in_flag, reg_target)
    return pts_in_flag, reg_target
Where:
pts_in_flag : [M, N]; the mask is 1 where a point lies inside a bbox.
Question: reg_target : [N, 3]; what are its values, and how are they obtained?
To resolve this question we need to understand points_op_cpu.pts_in_boxes3d. This function lives in mmdet/ops/points_op/src/points_op.cpp; let's take a look:
int pts_in_boxes3d_cpu(at::Tensor pts, at::Tensor boxes3d, at::Tensor pts_flag, at::Tensor reg_target){
    // param pts: (N, 3)
    // param boxes3d: (M, 7) [x, y, z, h, w, l, ry]
    // param pts_flag: (M, N)
    // param reg_target: (N, 3), center offsets

    CHECK_CONTIGUOUS(pts_flag);
    CHECK_CONTIGUOUS(pts);
    CHECK_CONTIGUOUS(boxes3d);
    CHECK_CONTIGUOUS(reg_target);

    long boxes_num = boxes3d.size(0);
    long pts_num = pts.size(0);

    int * pts_flag_flat = pts_flag.data<int>();
    float * pts_flat = pts.data<float>();
    float * boxes3d_flat = boxes3d.data<float>();
    float * reg_target_flat = reg_target.data<float>();

    // memset(assign_idx_flat, -1, boxes_num * pts_num * sizeof(int));
    // memset(reg_target_flat, 0, pts_num * sizeof(float));

    // The tensors are traversed here as flattened (1-D) arrays
    int i, j, cur_in_flag;
    for (i = 0; i < boxes_num; i++){
        for (j = 0; j < pts_num; j++){
            cur_in_flag = pt_in_box3d_cpu(pts_flat[j * 3], pts_flat[j * 3 + 1], pts_flat[j * 3 + 2], boxes3d_flat[i * 7],
                                          boxes3d_flat[i * 7 + 1], boxes3d_flat[i * 7 + 2], boxes3d_flat[i * 7 + 3],
                                          boxes3d_flat[i * 7 + 4], boxes3d_flat[i * 7 + 5], boxes3d_flat[i * 7 + 6]);
            pts_flag_flat[i * pts_num + j] = cur_in_flag;
            if(cur_in_flag==1){
                reg_target_flat[j*3] = pts_flat[j*3] - boxes3d_flat[i*7];
                reg_target_flat[j*3+1] = pts_flat[j*3+1] - boxes3d_flat[i*7+1];
                reg_target_flat[j*3+2] = pts_flat[j*3+2] - (boxes3d_flat[i*7+2] + boxes3d_flat[i*7+3] / 2.0);
            }
        }
    }
    return 1;
}
It is now roughly clear what this function does: a double loop checks whether each point of the cloud lies inside any of the given bboxes; for a point inside a bbox, the point's coordinates minus the bbox center's coordinates give reg_target (for z, half the box height h is added first, since the stored z is the bottom center). As a formula:
reg_target = P_i(x, y, z) − P_center(x, y, z)
OK, the question above is resolved.
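As a cross-check, here is a pure-PyTorch sketch of the same logic (my own reimplementation, not the project's code; the rotation convention of pt_in_box3d_cpu is not shown above, so this assumes a yaw rotation about the z-axis and the (x, y, z, h, w, l, ry) layout with z at the bottom center):

import torch

def pts_in_boxes3d_torch(pts, boxes3d):
    # pts: [N, 3]; boxes3d: [M, 7] as (x, y, z, h, w, l, ry)
    N, M = pts.shape[0], boxes3d.shape[0]
    pts_in_flag = torch.zeros(M, N, dtype=torch.int32)
    reg_target = torch.zeros(N, 3)
    for i in range(M):
        x, y, z, h, w, l, ry = boxes3d[i]
        # rotate the points into the box's local frame
        c, s = torch.cos(-ry), torch.sin(-ry)
        local_x = (pts[:, 0] - x) * c - (pts[:, 1] - y) * s
        local_y = (pts[:, 0] - x) * s + (pts[:, 1] - y) * c
        in_box = ((local_x.abs() < l / 2) & (local_y.abs() < w / 2)
                  & (pts[:, 2] > z) & (pts[:, 2] < z + h))
        pts_in_flag[i] = in_box.int()
        # offset to the geometric center (bottom-center z shifted up by h / 2)
        center = torch.stack([x, y, z + h / 2])
        reg_target[in_box] = pts[in_box] - center
    return pts_in_flag, reg_target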
This part is the head of the whole network. First a brief outline, then a detailed analysis.
class SSDRotateHead(nn.Module):

    def __init__(self,
                 num_class=1,
                 num_output_filters=768,
                 num_anchor_per_loc=2,
                 use_sigmoid_cls=True,
                 encode_rad_error_by_sin=True,
                 use_direction_classifier=True,
                 box_coder='GroundBox3dCoder',
                 box_code_size=7,
                 ):
        super(SSDRotateHead, self).__init__()
        self._num_class = num_class
        self._num_anchor_per_loc = num_anchor_per_loc
        self._use_direction_classifier = use_direction_classifier
        self._use_sigmoid_cls = use_sigmoid_cls
        self._encode_rad_error_by_sin = encode_rad_error_by_sin
        self._use_direction_classifier = use_direction_classifier
        self._box_coder = getattr(boxCoders, box_coder)()
        self._box_code_size = box_code_size
        self._num_output_filters = num_output_filters
        if use_sigmoid_cls:  # True
            num_cls = num_anchor_per_loc * num_class  # 2 * 1
        else:
            num_cls = num_anchor_per_loc * (num_class + 1)
        self.conv_cls = nn.Conv2d(num_output_filters, num_cls, 1)
        self.conv_box = nn.Conv2d(
            num_output_filters, num_anchor_per_loc * box_code_size, 1)
        if use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(
                num_output_filters, num_anchor_per_loc * 2, 1)

    def add_sin_difference(self, boxes1, boxes2):

    def get_direction_target(self, anchors, reg_targets, use_one_hot=True):

    def prepare_loss_weights(self, labels,
                             pos_cls_weight=1.0,
                             neg_cls_weight=1.0,
                             loss_norm_type='NormByNumPositives',
                             dtype=torch.float32):

    def create_loss(self,
                    box_preds,     # torch.Size([2, 200, 176, 14])
                    cls_preds,     # torch.Size([2, 200, 176, 2])
                    cls_targets,   # torch.Size([2, 70400])
                    cls_weights,   # torch.Size([2, 70400])
                    reg_targets,   # torch.Size([2, 70400, 7])
                    reg_weights,   # torch.Size([2, 70400])
                    num_class,     # 1
                    use_sigmoid_cls=True,          # True
                    encode_rad_error_by_sin=True,  # True
                    box_code_size=7):              # 7

    def forward(self, x):

    def get_guided_anchors(self, box_preds, cls_preds, dir_cls_preds, anchors, anchors_mask, gt_bboxes, thr=.1):
First, look at the forward function:
def forward(self, x):  # torch.Size([2, 256, 200, 176])
    box_preds = self.conv_box(x)
    cls_preds = self.conv_cls(x)
    # [N, C, y(H), x(W)]
    box_preds = box_preds.permute(0, 2, 3, 1).contiguous()  # torch.Size([2, 200, 176, 14])
    cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()  # torch.Size([2, 200, 176, 2])
    if self._use_direction_classifier:
        dir_cls_preds = self.conv_dir_cls(x)
        dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()  # torch.Size([2, 200, 176, 4])
    return box_preds, cls_preds, dir_cls_preds
The input is the feature map obtained from the backbone/neck; it is then split into branches that predict the bboxes and the object classes, respectively (plus the optional direction classifier).
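A quick standalone sanity check of the tensor shapes (the 256 input channels and per-branch channel counts are taken from the shape annotations above):

import torch
import torch.nn as nn

x = torch.randn(2, 256, 200, 176)    # neck output [B, C, H, W]
conv_box = nn.Conv2d(256, 2 * 7, 1)  # num_anchor_per_loc * box_code_size
conv_cls = nn.Conv2d(256, 2 * 1, 1)  # num_anchor_per_loc * num_class
conv_dir = nn.Conv2d(256, 2 * 2, 1)  # num_anchor_per_loc * 2 directions

box_preds = conv_box(x).permute(0, 2, 3, 1).contiguous()
cls_preds = conv_cls(x).permute(0, 2, 3, 1).contiguous()
dir_preds = conv_dir(x).permute(0, 2, 3, 1).contiguous()
print(box_preds.shape, cls_preds.shape, dir_preds.shape)
# torch.Size([2, 200, 176, 14]) torch.Size([2, 200, 176, 2]) torch.Size([2, 200, 176, 4])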
Now let's see how the loss is constructed:
# input
# box_preds : torch.Size([2, 200, 176, 14])
# cls_preds : torch.Size([2, 200, 176, 2])
# gt_bboxes : list:len(gt_bboxes) = B , gt_bboxes[0].shape = torch.Size([num_of_gt_bboxes, 7])
# anchor : torch.Size([2, 70400, 7])
# anchor_mask : torch.Size([2, 70400])
# cfg : from car_cfg.py / train_cfg
def loss(self, box_preds, cls_preds, dir_cls_preds, gt_bboxes, gt_labels, anchors, anchors_mask, cfg):
    batch_size = box_preds.shape[0]

    # ADD----------------------------------------------------------------------------------------------
    add_for_test = False
    add_for_pkl = False
    # for showing gt_bboxes
    if add_for_test == True:
        bbox3d_for_test = gt_bboxes[0].cpu().numpy()
        draw_gt_boxes3d_for_test(center_to_corner_box3d(bbox3d_for_test), draw_text=True, show=True)
    # for visualizing the anchors
    if add_for_pkl == True:
        pkl_data = {}
        pkl_data['anchors'] = anchors
        pkl_data['anchors_mask'] = anchors_mask
        import pickle
        with open("/home/seivl/pkl_data.pkl", 'wb') as fo:
            pickle.dump(pkl_data, fo)
    # -----------------------------------------------------------------------------------------------

    # The first argument, create_target_torch, is the function to apply;
    # the remaining variables are passed into it as arguments.
    # targets are the regression targets.
    labels, targets, ious = multi_apply(create_target_torch,
                                        anchors, gt_bboxes,
                                        anchors_mask, gt_labels,
                                        similarity_fn=getattr(iou3d_utils, cfg.assigner.similarity_fn)(),
                                        box_encoding_fn=second_box_encode,
                                        matched_threshold=cfg.assigner.pos_iou_thr,
                                        unmatched_threshold=cfg.assigner.neg_iou_thr,
                                        box_code_size=self._box_code_size)

    labels = torch.stack(labels)
    targets = torch.stack(targets)

    # Generate the cls and reg weights
    cls_weights, reg_weights, cared = self.prepare_loss_weights(labels)
    # Generate the cls targets
    cls_targets = labels * cared.type_as(labels)

    # Build the loss (detailed analysis below)
    loc_loss, cls_loss = self.create_loss(
        box_preds=box_preds,
        cls_preds=cls_preds,
        cls_targets=cls_targets,
        cls_weights=cls_weights,
        reg_targets=targets,
        reg_weights=reg_weights,
        num_class=self._num_class,
        encode_rad_error_by_sin=self._encode_rad_error_by_sin,
        use_sigmoid_cls=self._use_sigmoid_cls,
        box_code_size=self._box_code_size,
    )

    loc_loss_reduced = loc_loss / batch_size
    loc_loss_reduced *= 2  # weight of loc_loss

    cls_loss_reduced = cls_loss / batch_size
    cls_loss_reduced *= 1

    loss = loc_loss_reduced + cls_loss_reduced

    if self._use_direction_classifier:
        # Generate the ground-truth dir_labels corresponding to dir_cls_preds
        dir_labels = self.get_direction_target(anchors, targets, use_one_hot=False).view(-1)
        dir_logits = dir_cls_preds.view(-1, 2)
        # The weights are set so that only targets with labels > 0 (i.e. the car class) count
        weights = (labels > 0).type_as(dir_logits)
        weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
        # Use cross-entropy as the loss for the direction prediction
        dir_loss = weighted_cross_entropy(dir_logits, dir_labels,
                                          weight=weights.view(-1),
                                          avg_factor=1.)
        dir_loss_reduced = dir_loss / batch_size
        dir_loss_reduced *= .2
        loss += dir_loss_reduced

    return dict(rpn_loc_loss=loc_loss_reduced, rpn_cls_loss=cls_loss_reduced, rpn_dir_loss=dir_loss_reduced)
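prepare_loss_weights and get_direction_target appear above only as signatures. For reference, here is a hedged sketch of what SECOND-style implementations of these two typically do (the default 'NormByNumPositives' scheme; SA-SSD's exact code may differ in detail):

import torch

def prepare_loss_weights_sketch(labels, pos_cls_weight=1.0, neg_cls_weight=1.0):
    # labels: [B, num_anchors]; 1 = positive, 0 = negative, -1 = ignore
    cared = labels >= 0
    positives = labels > 0
    negatives = labels == 0
    cls_weights = neg_cls_weight * negatives.float() + pos_cls_weight * positives.float()
    reg_weights = positives.float()
    # 'NormByNumPositives': normalize both by the positives count per sample
    pos_normalizer = positives.sum(1, keepdim=True).float()
    reg_weights /= torch.clamp(pos_normalizer, min=1.0)
    cls_weights /= torch.clamp(pos_normalizer, min=1.0)
    return cls_weights, reg_weights, cared

def get_direction_target_sketch(anchors, reg_targets):
    # The direction label is whether the true yaw (anchor yaw + regression
    # residual) is positive; this resolves the 180-degree ambiguity that the
    # sin-based angle encoding introduces.
    rot_gt = reg_targets[..., -1] + anchors[..., -1]
    return (rot_gt > 0).long()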
A very important function here is create_target_torch, which generates the labels; its detailed analysis comes later.
The concrete loss-construction function:
def create_loss(self,
                box_preds,     # torch.Size([2, 200, 176, 14])
                cls_preds,     # torch.Size([2, 200, 176, 2])
                cls_targets,   # torch.Size([2, 70400])
                cls_weights,   # torch.Size([2, 70400])
                reg_targets,   # torch.Size([2, 70400, 7])
                reg_weights,   # torch.Size([2, 70400])
                num_class,     # 1
                use_sigmoid_cls=True,          # True
                encode_rad_error_by_sin=True,  # True
                box_code_size=7):              # 7
    batch_size = int(box_preds.shape[0])  # B = 2
    box_preds = box_preds.view(batch_size, -1, box_code_size)  # torch.Size([2, 70400, 7])
    if use_sigmoid_cls:
        cls_preds = cls_preds.view(batch_size, -1, num_class)  # torch.Size([2, 70400, 1])
    else:
        cls_preds = cls_preds.view(batch_size, -1, num_class + 1)
    one_hot_targets = one_hot(
        cls_targets, depth=num_class + 1, dtype=box_preds.dtype)  # torch.Size([2, 70400, 2])
    if use_sigmoid_cls:
        one_hot_targets = one_hot_targets[..., 1:]  # torch.Size([2, 70400, 1])
    if encode_rad_error_by_sin:
        # both outputs: torch.Size([2, 70400, 7])
        box_preds, reg_targets = self.add_sin_difference(box_preds, reg_targets)
    loc_losses = weighted_smoothl1(box_preds, reg_targets, beta=1 / 9.,
                                   weight=reg_weights[..., None], avg_factor=1.)
    cls_losses = weighted_sigmoid_focal_loss(cls_preds, one_hot_targets,
                                             weight=cls_weights[..., None], avg_factor=1.)
    return loc_losses, cls_losses
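add_sin_difference is what encode_rad_error_by_sin relies on: rather than regressing the raw angle difference, the yaw channels of prediction and target are replaced so that the smooth-L1 loss effectively penalizes sin(θp − θt) = sin θp cos θt − cos θp sin θt. A sketch of the SECOND-style implementation (SA-SSD's version may differ in detail):

import torch

def add_sin_difference_sketch(boxes1, boxes2):
    # Replace the last (yaw) channel: sin(a)·cos(b) on the prediction side and
    # cos(a)·sin(b) on the target side, so their difference equals sin(a - b).
    rad_pred = torch.sin(boxes1[..., -1:]) * torch.cos(boxes2[..., -1:])
    rad_tg = torch.cos(boxes1[..., -1:]) * torch.sin(boxes2[..., -1:])
    boxes1 = torch.cat([boxes1[..., :-1], rad_pred], dim=-1)
    boxes2 = torch.cat([boxes2[..., :-1], rad_tg], dim=-1)
    return boxes1, boxes2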
The core is the create_target_torch function; the annotations and analysis follow. The main purpose of this code is to generate:
the labels for the anchors
the targets for bbox regression
# all_anchors : torch.Size([70400, 7])
# gt_boxes : torch.Size([num_of_gt_bbox, 7])
# anchor_mask : torch.Size([70400])
# gt_classes : length num_of_gt_bbox, e.g. 14
# similarity_fn :
# box_encoding_fn :
# matched_threshold : 0.6
# unmatched_threshold : 0.45
# positive_fraction : None
# norm_by_num_examples : False
# box_code_size : 7
def create_target_torch(all_anchors,
                        gt_boxes,
                        anchor_mask,
                        gt_classes,
                        similarity_fn,
                        box_encoding_fn,
                        matched_threshold=0.6,
                        unmatched_threshold=0.45,
                        positive_fraction=None,
                        rpn_batch_size=300,
                        norm_by_num_examples=False,
                        box_code_size=7):
    # torch.set_printoptions(threshold=np.inf)

    # This helper maps results computed on the masked anchors back to the full anchor set
    def _unmap(data, count, inds, fill=0):
        # ----------------------------
        # data : labels
        # count : number of anchors
        # inds : mask
        # ----------------------------
        """ Unmap a subset of items (data) back to the original set of items (of
        size count) """
        if data.dim() == 1:
            ret = data.new_full((count,), fill)
            ret[inds] = data
        else:
            new_size = (count,) + data.size()[1:]
            ret = data.new_full(new_size, fill)
            ret[inds, :] = data
        return ret
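    # Example of _unmap (hypothetical numbers): with count = 5,
    # inds = [True, False, True, False, True] and data = [1, 0, 1],
    # it returns [1, -1, 0, -1, 1] for fill = -1, i.e. the labels computed on
    # the masked anchors are scattered back to all 70400 positions.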
    # value: 70400
    total_anchors = all_anchors.shape[0]
    if anchor_mask is not None:
        # inds_inside = np.where(anchors_mask)[0]  # prune_anchor_fn(all_anchors)
        # value: 22007
        anchors = all_anchors[anchor_mask, :]
        if not isinstance(matched_threshold, float):
            matched_threshold = matched_threshold[anchor_mask]
        if not isinstance(unmatched_threshold, float):
            unmatched_threshold = unmatched_threshold[anchor_mask]
    else:
        anchors = all_anchors
        # inds_inside = None

    # 22007
    num_inside = len(torch.nonzero(anchor_mask)) if anchor_mask is not None else total_anchors

    if gt_classes is None:
        gt_classes = torch.ones([gt_boxes.shape[0]], dtype=torch.int64, device=gt_boxes.device)

    # Compute anchor labels:
    # label=1 is positive, 0 is negative, -1 is don't care (ignore)
    # shape = [22007], initial value = -1
    labels = torch.empty((num_inside,), dtype=torch.int64, device=gt_boxes.device).fill_(-1)
    gt_ids = torch.empty((num_inside,), dtype=torch.int64, device=gt_boxes.device).fill_(-1)
    if len(gt_boxes) > 0 and anchors.shape[0] > 0:
        # Compute overlaps between the anchors and the gt boxes,
        # i.e. the IoU between every anchor and every gt_bbox
        anchor_by_gt_overlap = similarity_fn(anchors, gt_boxes)  # torch.Size([22007, 14])
        # add for test
        # for_test_anchor_by_gt_overlap = similarity_fn(anchors[9300:9303, :], gt_boxes)

        # Map from anchor to the gt box that has the highest overlap.
        # shape: 22007
        # For each anchor, the index of the gt_bbox with the largest IoU
        # (dim=1 reduces over the gt dimension)
        anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)

        # For each anchor, amount of overlap with the most overlapping gt box,
        # i.e. the largest IoU between each anchor and the gt_bboxes
        anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_inside),
                                                anchor_to_gt_argmax]

        # Map from gt box to an anchor that has the highest overlap.
        # For each gt_bbox, the index of the anchor with the largest IoU
        # (dim=0 reduces over the anchor dimension)
        # shape: 14
        gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)

        # For each gt box, amount of overlap with the most overlapping anchor,
        # i.e. the largest IoU between each gt_bbox and the anchors
        gt_to_anchor_max = anchor_by_gt_overlap[
            gt_to_anchor_argmax,
            torch.arange(anchor_by_gt_overlap.shape[1])]

        # must remove gt which doesn't match any anchor.
        empty_gt_mask = gt_to_anchor_max == 0
        gt_to_anchor_max[empty_gt_mask] = -1

        # Find all anchors that share the max overlap amount
        # (this includes many ties); these are the anchors with the largest
        # IoU against some gt_bbox, e.g.:
        # tensor([ 6287,  7063,  9302,  9530,  9571, 10225, 11481, 13080, 14509, 15080,
        #         15082, 15293, 18273, 18740, 21316], device='cuda:0')
        anchors_with_max_overlap = torch.nonzero(
            anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
        # for test
        # for_test_anchors_with_max_overlap = torch.nonzero(
        #     for_test_anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]

        # Fg label: for each gt use anchors with highest overlap
        # (including ties); gt_inds_force records which gt_bbox each of these
        # anchors corresponds to, e.g. 15 entries:
        # tensor([ 6, 10, 12, 11, 13,  7,  9,  5,  3,  2,  2,  8,  1,  0,  4],
        #        device='cuda:0')
        gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
        labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]  # set the labels of the max-IoU anchors to 1
        gt_ids[anchors_with_max_overlap] = gt_inds_force  # store the index of the matched gt

        # Fg label: above threshold IOU
        pos_inds = anchor_to_gt_max >= matched_threshold  # anchors whose IoU with a gt exceeds the threshold
        gt_inds = anchor_to_gt_argmax[pos_inds]  # the gt_bbox index each such anchor matches
        # here 67 anchors have IoU with a gt_bbox above the threshold, e.g.:
        # tensor([ 6,  6,  6,  6,  6,  6, 10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 11, 11,
        #         12, 11, 11, 13, 13, 11, 13, 13,  7,  7,  7,  7,  7,  9,  9,  9,  9,  5,
        #          5,  5,  5,  5,  3,  3,  3,  3,  2,  2,  2,  2,  8,  8,  8,  8,  1,  1,
        #          1,  1,  1,  0,  0,  0,  0,  0,  4,  4,  4,  4,  4], device='cuda:0')
        labels[pos_inds] = gt_classes[gt_inds]  # set the corresponding labels to 1
        gt_ids[pos_inds] = gt_inds  # store the index of the matched gt

        # bg_inds = np.where(anchor_to_gt_max < unmatched_threshold)[0]
        bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
        # indices of the anchors below the unmatched threshold
    else:
        bg_inds = torch.arange(num_inside)

    # fg_inds = np.where(labels > 0)[0]
    fg_inds = torch.nonzero(labels > 0)[:, 0]
    # indices of all foreground anchors, e.g.:
    # tensor([ 6283,  6285,  6287,  6289,  6291,  6498,  6852,  6854,  7061,  7063,
    #          7268,  7270,  8883,  9094,  9300,  9302,  9324,  9326,  9508,  9530,
    #          9532,  9571,  9573,  9736,  9777,  9779,  9827, 10028, 10225, 10227,
    #         10424, 11481, 11483, 11757, 11759, 13078, 13080, 13082, 13084, 13366,
    #         14267, 14509, 14511, 14750, 15078, 15080, 15082, 15084, 15291, 15293,
    #         15295, 15553, 18009, 18269, 18271, 18273, 18275, 18493, 18495, 18738,
    #         18740, 18742, 21312, 21314, 21316, 21318, 21389], device='cuda:0')
    # subsample positive labels if we have too many
    if positive_fraction is not None:
        num_fg = int(positive_fraction * rpn_batch_size)
        if len(fg_inds) > num_fg:
            disable_inds = npr.choice(
                fg_inds, size=(len(fg_inds) - num_fg), replace=False)
            labels[disable_inds] = -1
            # fg_inds = np.where(labels > 0)[0]
            fg_inds = torch.nonzero(labels > 0)[:, 0]

        # subsample negative labels if we have too many
        # (samples with replacement, but since the set of bg inds is large most
        # samples will not have repeats)
        num_bg = rpn_batch_size - np.sum(labels > 0)
        # print(num_fg, num_bg, len(bg_inds))
        if len(bg_inds) > num_bg:
            enable_inds = bg_inds[npr.randint(len(bg_inds), size=num_bg)]
            labels[enable_inds] = 0
    else:
        if len(gt_boxes) == 0 or anchors.shape[0] == 0:
            labels[:] = 0
        else:
            labels[bg_inds] = 0  # background anchors get label 0
            # re-enable anchors_with_max_overlap
            labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
    # Generate the regression targets
    bbox_targets = torch.zeros(
        (num_inside, box_code_size), dtype=all_anchors.dtype, device=gt_boxes.device)  # torch.Size([22007, 7])

    # Encode the gt boxes against the foreground anchors
    if len(gt_boxes) > 0 and anchors.shape[0] > 0:
        bbox_targets[fg_inds, :] = box_encoding_fn(
            gt_boxes[anchor_to_gt_argmax[fg_inds], :], anchors[fg_inds, :])
        # bbox_targets[fg_inds, :].shape : torch.Size([67, 7])

    bbox_outside_weights = torch.zeros((num_inside,), dtype=all_anchors.dtype, device=gt_boxes.device)
    # uniform weighting of examples (given non-uniform sampling)
    if norm_by_num_examples:
        num_examples = torch.sum(labels >= 0)  # neg + pos
        num_examples = np.maximum(1.0, num_examples)
        bbox_outside_weights[labels > 0] = 1.0 / num_examples
    else:
        bbox_outside_weights[labels > 0] = 1.0

    # Map up to original set of anchors
    if anchor_mask is not None:
        labels = _unmap(labels, total_anchors, anchor_mask, fill=-1)
        bbox_targets = _unmap(bbox_targets, total_anchors, anchor_mask, fill=0)

    return (labels, bbox_targets, anchor_to_gt_max)
# Returned values:
#   labels.shape : torch.Size([70400])
#   bbox_targets.shape : torch.Size([70400, 7])
#   anchor_to_gt_max : length 22007 (over the masked anchors)
# Label convention:
#   foreground = 1
#   background = 0
#   don't care (ignored) = -1
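box_encoding_fn here is second_box_encode (passed in from loss() above). For completeness, a sketch of the SECOND-style residual encoding it performs (channel order and details assumed from the SECOND codebase, so treat it as illustrative):

import torch

def second_box_encode_sketch(gt_boxes, anchors):
    # gt_boxes, anchors: [N, 7] as (x, y, z, w, l, h, ry)
    xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1)
    xg, yg, zg, wg, lg, hg, rg = torch.split(gt_boxes, 1, dim=-1)
    diag = torch.sqrt(la ** 2 + wa ** 2)  # anchor diagonal normalizes the x/y offsets
    xt = (xg - xa) / diag
    yt = (yg - ya) / diag
    zt = (zg - za) / ha                   # z offset normalized by the anchor height
    wt = torch.log(wg / wa)               # sizes encoded as log ratios
    lt = torch.log(lg / la)
    ht = torch.log(hg / ha)
    rt = rg - ra                          # raw yaw residual (sin-encoded later)
    return torch.cat([xt, yt, zt, wt, lt, ht, rt], dim=-1)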
OK, to be continued.