All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested with PyTorch 1.4 and verified to run correctly.
The biggest difference between object detection and ordinary classification lies in what counts as a sample. In a classification task, each image is one sample; in RetinaNet, every anchor in an image is a sample. Depending on how anchor labels are assigned, object detectors fall into anchor-based and anchor-free families. Typical anchor-based detectors are Faster R-CNN and RetinaNet. Anchor-free detectors have only taken off in the last couple of years, with FCOS (2019) being a representative example.
RetinaNet's label assignment rule is essentially the same as Faster R-CNN's, only with different IoU thresholds. For a single image, we first compute the IoU between every anchor and every annotated object in that image. For each anchor, the regression target is taken from the object with the highest IoU. The class label is then assigned based on that maximum IoU: if the maximum IoU is below 0.4, the label is set to 0, marking a negative anchor; if the maximum IoU is 0.5 or higher, the label is set to the class index of the corresponding object plus 1 (object class indices start from 0 in the dataset preprocessing, hence the +1), marking a positive anchor. The remaining anchors, those with a maximum IoU between 0.4 and 0.5, get a class label of -1 and are ignored: they take part in neither the focal loss nor the smooth L1 loss.
RetinaNet uses Faster R-CNN's anchor assignment rules, but Faster R-CNN has two assignment rules and the above is only the second one. Why is the first rule not reflected here?
That's right, the rule above is the second anchor assignment rule in Faster R-CNN (https://arxiv.org/pdf/1506.01497.pdf); RetinaNet only changes the IoU thresholds for positives and negatives. Faster R-CNN's first rule says that even when an object's maximum IoU with all anchors does not exceed 0.5, the anchor achieving that maximum IoU is still marked as positive. However, after traversing the COCO dataset we found that this situation is very rare, so we do not use the first rule. This means those few objects are effectively not used for training, but since there are so few of them, it has no noticeable effect on model performance.
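To make this check concrete, here is a rough counting sketch (a hypothetical helper, assuming you can iterate over per-image anchors and ground-truth boxes as tensors and reuse an IoU function such as compute_ious_for_one_image shown below) that measures how many objects never reach IoU 0.5 with any anchor:
import torch

def count_objects_without_positive_anchor(samples, compute_ious, iou_thresh=0.5):
    # samples: iterable yielding (anchors [N,4], gt_boxes [M,4]) per image (hypothetical)
    # compute_ious: returns an [N,M] IoU matrix, e.g. compute_ious_for_one_image
    missed, total = 0, 0
    for anchors, gt_boxes in samples:
        if gt_boxes.shape[0] == 0:
            continue
        ious = compute_ious(anchors, gt_boxes)    # [anchor_num, gt_num]
        best_iou_per_gt, _ = ious.max(dim=0)      # best anchor IoU for each object
        missed += (best_iou_per_gt < iou_thresh).sum().item()
        total += gt_boxes.shape[0]
    # objects counted in `missed` are exactly those only covered by Faster R-CNN's first rule
    return missed, total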
The annotations provide box coordinates, but training does not regress box coordinates directly. How is the conversion done?
Faster R-CNN first converts box coordinates into tx, ty, tw, th and then regresses them with a smooth L1 loss. Note that in Faster R-CNN implementations, the smooth L1 loss has an extra beta value that scales the loss up or down. This beta usually takes the empirical value 1/9; compared with beta = 1 in the original formulation, the loss is somewhat amplified. Also, in many Faster R-CNN implementations, after converting the box coordinates to tx, ty, tw, th with the Faster R-CNN formulas, the four values are further divided element-wise by [0.1, 0.1, 0.2, 0.2] to scale them up. I ran a comparison with and without this scaling and found that with the scaling the model converges faster and performs better.
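Concretely, with an anchor center and size (x_a, y_a, w_a, h_a) and a ground-truth box center and size (x, y, w, h), the Faster R-CNN parameterization is tx = (x - x_a) / w_a, ty = (y - y_a) / h_a, tw = log(w / w_a), th = log(h / h_a); in the code below these four values are then divided element-wise by [0.1, 0.1, 0.2, 0.2].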
The code that converts box coordinates into the regression targets tx, ty, tw, th is as follows:
def snap_annotations_as_tx_ty_tw_th(self, anchors_gt_bboxes, anchors):
"""
snap each anchor ground truth bbox from format:[x_min,y_min,x_max,y_max] to format:[tx,ty,tw,th]
"""
anchors_w_h = anchors[:, 2:] - anchors[:, :2]
anchors_ctr = anchors[:, :2] + 0.5 * anchors_w_h
anchors_gt_bboxes_w_h = anchors_gt_bboxes[:,
2:] - anchors_gt_bboxes[:, :2]
anchors_gt_bboxes_w_h = torch.clamp(anchors_gt_bboxes_w_h, min=1.0)
anchors_gt_bboxes_ctr = anchors_gt_bboxes[:, :
2] + 0.5 * anchors_gt_bboxes_w_h
snaped_annotations_for_anchors = torch.cat(
[(anchors_gt_bboxes_ctr - anchors_ctr) / anchors_w_h,
torch.log(anchors_gt_bboxes_w_h / anchors_w_h)],
axis=1)
device = snaped_annotations_for_anchors.device
factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
snaped_annotations_for_anchors = snaped_annotations_for_anchors / factor
# snaped_annotations_for_anchors shape:[anchor_nums, 4]
return snaped_annotations_for_anchors
The IoU computation code is as follows:
def compute_ious_for_one_image(self, one_image_anchors,
one_image_annotations):
"""
compute ious between one image anchors and one image annotations
"""
# make sure anchors format:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
# make sure annotations format: [annotation_nums,4],4:[x_min,y_min,x_max,y_max]
annotation_num = one_image_annotations.shape[0]
one_image_ious = []
for annotation_index in range(annotation_num):
annotation = one_image_annotations[
annotation_index:annotation_index + 1, :]
overlap_area_top_left = torch.max(one_image_anchors[:, :2],
annotation[:, :2])
overlap_area_bot_right = torch.min(one_image_anchors[:, 2:],
annotation[:, 2:])
overlap_area_sizes = torch.clamp(overlap_area_bot_right -
overlap_area_top_left,
min=0)
overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
# compute anchor and annotation widths and heights (with the +1 pixel-inclusive convention)
anchors_w_h = one_image_anchors[:,
2:] - one_image_anchors[:, :2] + 1
annotations_w_h = annotation[:, 2:] - annotation[:, :2] + 1
# compute anchors_area and annotations_area
anchors_area = anchors_w_h[:, 0] * anchors_w_h[:, 1]
annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]
# compute union_area
union_area = anchors_area + annotations_area - overlap_area
union_area = torch.clamp(union_area, min=1e-4)
# compute ious between one image anchors and one image annotations
ious = (overlap_area / union_area).unsqueeze(-1)
one_image_ious.append(ious)
one_image_ious = torch.cat(one_image_ious, axis=1)
# one image ious shape:[anchors_num,annotation_num]
return one_image_ious
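As a quick worked example: for an anchor [0, 0, 100, 100] and an annotation [50, 50, 150, 150], the overlap is 50 × 50 = 2500; since widths and heights are computed with +1 here, each box area is 101 × 101 = 10201, so the IoU is 2500 / (10201 + 10201 - 2500) ≈ 0.14.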
The anchor label assignment code is as follows:
def get_batch_anchors_annotations(self, batch_anchors, annotations):
"""
Assign a ground truth box target and a ground truth class target to each anchor
if anchor gt_class index = -1, this anchor is ignored and calculates neither cls loss nor reg loss
if anchor gt_class index = 0, this anchor is a background anchor and is only used in the cls loss
if anchor gt_class index > 0, this anchor is an object anchor and is used in both the cls loss and the reg loss
"""
device = annotations.device
assert batch_anchors.shape[0] == annotations.shape[0]
one_image_anchor_nums = batch_anchors.shape[1]
batch_anchors_annotations = []
for one_image_anchors, one_image_annotations in zip(
batch_anchors, annotations):
# drop all index=-1 class annotations
one_image_annotations = one_image_annotations[
one_image_annotations[:, 4] >= 0]
if one_image_annotations.shape[0] == 0:
one_image_anchor_annotations = torch.ones(
[one_image_anchor_nums, 5], device=device) * (-1)
else:
one_image_gt_bboxes = one_image_annotations[:, 0:4]
one_image_gt_class = one_image_annotations[:, 4]
one_image_ious = self.compute_ious_for_one_image(
one_image_anchors, one_image_gt_bboxes)
# for each anchor, take the annotation with the maximum IoU
overlap, indices = one_image_ious.max(axis=1)
# assign each anchor the gt bbox of its max IoU annotation
per_image_anchors_gt_bboxes = one_image_gt_bboxes[indices]
# transform gt bboxes to [tx,ty,tw,th] format for each anchor
one_image_anchors_snaped_boxes = self.snap_annotations_as_tx_ty_tw_th(
per_image_anchors_gt_bboxes, one_image_anchors)
one_image_anchors_gt_class = (torch.ones_like(overlap) *
-1).to(device)
# if max iou < 0.4, assign gt class 0 (background)
one_image_anchors_gt_class[overlap < 0.4] = 0
# if max iou >= 0.5, assign the max iou annotation's class index + 1 (80 classes, indices 1 to 80)
one_image_anchors_gt_class[
overlap >=
0.5] = one_image_gt_class[indices][overlap >= 0.5] + 1
one_image_anchors_gt_class = one_image_anchors_gt_class.unsqueeze(
-1)
one_image_anchor_annotations = torch.cat([
one_image_anchors_snaped_boxes, one_image_anchors_gt_class
],
axis=1)
one_image_anchor_annotations = one_image_anchor_annotations.unsqueeze(
0)
batch_anchors_annotations.append(one_image_anchor_annotations)
batch_anchors_annotations = torch.cat(batch_anchors_annotations,
axis=0)
# batch anchors annotations shape:[batch_size, anchor_nums, 5]
return batch_anchors_annotations
RetinaNet training uses a focal loss for classification and a smooth L1 loss for regression.
For the focal loss, we filter out anchors with class index -1 and compute the loss only over positive and negative anchors (an image must contain both positive and negative anchors, otherwise neither the focal loss nor the smooth L1 loss is computed for it). The focal loss is essentially 80 binary BCE losses per anchor, with alpha and gamma used to counteract, respectively, the class imbalance and the imbalance between easy and hard samples. A positive sample here means an anchor whose 80-dimensional one-hot class vector has a 1 for some class, and a negative sample means an anchor whose one-hot vector is all zeros. Finally, as the RetinaNet paper explains, with alpha and gamma in place the easy negative samples no longer account for a large share of the total loss, so the summed focal loss only needs to be divided by the number of positive anchors.
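For each of the 80 binary terms, with p_t denoting the predicted probability of the ground-truth label (p for an entry whose one-hot target is 1, and 1 - p otherwise), the focal loss from the paper is FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where alpha_t is alpha for positive entries and 1 - alpha otherwise, with alpha = 0.25 and gamma = 2 by default.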
The focal loss implementation is as follows:
def compute_one_image_focal_loss(self, per_image_cls_heads,
per_image_anchors_annotations):
"""
compute one image focal loss(cls loss)
per_image_cls_heads:[anchor_num,num_classes]
per_image_anchors_annotations:[anchor_num,5]
"""
# filter out anchors with gt class = -1; they do not contribute to the focal loss
per_image_cls_heads = per_image_cls_heads[
per_image_anchors_annotations[:, 4] >= 0]
per_image_anchors_annotations = per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] >= 0]
per_image_cls_heads = torch.clamp(per_image_cls_heads,
min=self.epsilon,
max=1. - self.epsilon)
num_classes = per_image_cls_heads.shape[1]
# generate 80 binary ground truth classes for each anchor
loss_ground_truth = F.one_hot(per_image_anchors_annotations[:,
4].long(),
num_classes=num_classes + 1)
loss_ground_truth = loss_ground_truth[:, 1:]
loss_ground_truth = loss_ground_truth.float()
alpha_factor = torch.ones_like(per_image_cls_heads) * self.alpha
alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
alpha_factor, 1. - alpha_factor)
pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_heads,
1. - per_image_cls_heads)
focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
bce_loss = -(
loss_ground_truth * torch.log(per_image_cls_heads) +
(1. - loss_ground_truth) * torch.log(1. - per_image_cls_heads))
one_image_focal_loss = focal_weight * bce_loss
one_image_focal_loss = one_image_focal_loss.sum()
positive_anchors_num = per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] > 0].shape[0]
# following the original paper, divide the focal loss by the number of positive anchors
one_image_focal_loss = one_image_focal_loss / positive_anchors_num
return one_image_focal_loss
For the smooth L1 loss, following the RetinaNet paper, we compute the loss only over positive anchors and also divide by the number of positive anchors. In practice, however, the smooth L1 loss computed this way is about 4 times larger than the focal loss, so we first average over the four components tx, ty, tw, th for each anchor, then sum over all positive anchors and divide by their number.
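For reference, the smooth L1 loss with beta used below is loss(x) = 0.5 * x^2 / beta if |x| < beta, and |x| - 0.5 * beta otherwise, where x is the absolute difference between the predicted and target tx, ty, tw, th and beta = 1/9 here.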
The smooth L1 loss implementation is as follows:
def compute_one_image_smoothl1_loss(self, per_image_reg_heads,
per_image_anchors_annotations):
"""
compute one image smoothl1 loss(reg loss)
per_image_reg_heads:[anchor_num,4]
per_image_anchors_annotations:[anchor_num,5]
"""
# keep only positive anchors (gt class > 0); other anchors do not contribute to the smooth l1 loss
device = per_image_reg_heads.device
per_image_reg_heads = per_image_reg_heads[
per_image_anchors_annotations[:, 4] > 0]
per_image_anchors_annotations = per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] > 0]
positive_anchor_num = per_image_anchors_annotations.shape[0]
if positive_anchor_num == 0:
return torch.tensor(0.).to(device)
# compute smoothl1 loss
loss_ground_truth = per_image_anchors_annotations[:, 0:4]
x = torch.abs(per_image_reg_heads - loss_ground_truth)
one_image_smoothl1_loss = torch.where(torch.ge(x, self.beta),
x - 0.5 * self.beta,
0.5 * (x**2) / self.beta)
one_image_smoothl1_loss = one_image_smoothl1_loss.mean(axis=1).sum()
# following the original paper, divide the smooth l1 loss by the number of positive anchors
one_image_smoothl1_loss = one_image_smoothl1_loss / positive_anchor_num
return one_image_smoothl1_loss
Before computing the losses, we follow the same practice as Faster R-CNN and first remove all anchors that extend beyond the image border; these anchors take no part in the loss computation. In addition, if an image contains no objects, none of its anchors can be positive, and we simply set that image's focal loss and smooth L1 loss to 0.
The complete RetinaNet loss implementation is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
class RetinaLoss(nn.Module):
def __init__(self,
image_w,
image_h,
alpha=0.25,
gamma=2,
beta=1.0 / 9.0,
epsilon=1e-4):
super(RetinaLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.beta = beta
self.epsilon = epsilon
self.image_w = image_w
self.image_h = image_h
def forward(self, cls_heads, reg_heads, batch_anchors, annotations):
"""
compute cls loss and reg loss in one batch
"""
device = annotations.device
cls_heads = torch.cat(cls_heads, axis=1)
reg_heads = torch.cat(reg_heads, axis=1)
batch_anchors = torch.cat(batch_anchors, axis=1)
cls_heads, reg_heads, batch_anchors = self.drop_out_border_anchors_and_heads(
cls_heads, reg_heads, batch_anchors, self.image_w, self.image_h)
batch_anchors_annotations = self.get_batch_anchors_annotations(
batch_anchors, annotations)
cls_loss, reg_loss = [], []
valid_image_num = 0
for per_image_cls_heads, per_image_reg_heads, per_image_anchors_annotations in zip(
cls_heads, reg_heads, batch_anchors_annotations):
# an image is only counted as valid if it contains at least one positive anchor
valid_anchors_num = (per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] > 0]).shape[0]
if valid_anchors_num == 0:
cls_loss.append(torch.tensor(0.).to(device))
reg_loss.append(torch.tensor(0.).to(device))
else:
valid_image_num += 1
one_image_cls_loss = self.compute_one_image_focal_loss(
per_image_cls_heads, per_image_anchors_annotations)
one_image_reg_loss = self.compute_one_image_smoothl1_loss(
per_image_reg_heads, per_image_anchors_annotations)
cls_loss.append(one_image_cls_loss)
reg_loss.append(one_image_reg_loss)
cls_loss = sum(cls_loss) / valid_image_num
reg_loss = sum(reg_loss) / valid_image_num
return cls_loss, reg_loss
def compute_one_image_focal_loss(self, per_image_cls_heads,
per_image_anchors_annotations):
"""
compute one image focal loss(cls loss)
per_image_cls_heads:[anchor_num,num_classes]
per_image_anchors_annotations:[anchor_num,5]
"""
# filter out anchors with gt class = -1; they do not contribute to the focal loss
per_image_cls_heads = per_image_cls_heads[
per_image_anchors_annotations[:, 4] >= 0]
per_image_anchors_annotations = per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] >= 0]
per_image_cls_heads = torch.clamp(per_image_cls_heads,
min=self.epsilon,
max=1. - self.epsilon)
num_classes = per_image_cls_heads.shape[1]
# generate 80 binary ground truth classes for each anchor
loss_ground_truth = F.one_hot(per_image_anchors_annotations[:,
4].long(),
num_classes=num_classes + 1)
loss_ground_truth = loss_ground_truth[:, 1:]
loss_ground_truth = loss_ground_truth.float()
alpha_factor = torch.ones_like(per_image_cls_heads) * self.alpha
alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
alpha_factor, 1. - alpha_factor)
pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_heads,
1. - per_image_cls_heads)
focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)
bce_loss = -(
loss_ground_truth * torch.log(per_image_cls_heads) +
(1. - loss_ground_truth) * torch.log(1. - per_image_cls_heads))
one_image_focal_loss = focal_weight * bce_loss
one_image_focal_loss = one_image_focal_loss.sum()
positive_anchors_num = per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] > 0].shape[0]
# following the original paper, divide the focal loss by the number of positive anchors
one_image_focal_loss = one_image_focal_loss / positive_anchors_num
return one_image_focal_loss
def compute_one_image_smoothl1_loss(self, per_image_reg_heads,
per_image_anchors_annotations):
"""
compute one image smoothl1 loss(reg loss)
per_image_reg_heads:[anchor_num,4]
per_image_anchors_annotations:[anchor_num,5]
"""
# keep only positive anchors (gt class > 0); other anchors do not contribute to the smooth l1 loss
device = per_image_reg_heads.device
per_image_reg_heads = per_image_reg_heads[
per_image_anchors_annotations[:, 4] > 0]
per_image_anchors_annotations = per_image_anchors_annotations[
per_image_anchors_annotations[:, 4] > 0]
positive_anchor_num = per_image_anchors_annotations.shape[0]
if positive_anchor_num == 0:
return torch.tensor(0.).to(device)
# compute smoothl1 loss
loss_ground_truth = per_image_anchors_annotations[:, 0:4]
x = torch.abs(per_image_reg_heads - loss_ground_truth)
one_image_smoothl1_loss = torch.where(torch.ge(x, self.beta),
x - 0.5 * self.beta,
0.5 * (x**2) / self.beta)
one_image_smoothl1_loss = one_image_smoothl1_loss.mean(axis=1).sum()
# following the original paper, divide the smooth l1 loss by the number of positive anchors
one_image_smoothl1_loss = one_image_smoothl1_loss / positive_anchor_num
return one_image_smoothl1_loss
def drop_out_border_anchors_and_heads(self, cls_heads, reg_heads,
batch_anchors, image_w, image_h):
"""
drop anchors that fall outside the image border, together with their cls heads and reg heads
"""
final_cls_heads, final_reg_heads, final_batch_anchors = [], [], []
for per_image_cls_head, per_image_reg_head, per_image_anchors in zip(
cls_heads, reg_heads, batch_anchors):
per_image_cls_head = per_image_cls_head[per_image_anchors[:,
0] > 0.0]
per_image_reg_head = per_image_reg_head[per_image_anchors[:,
0] > 0.0]
per_image_anchors = per_image_anchors[per_image_anchors[:,
0] > 0.0]
per_image_cls_head = per_image_cls_head[per_image_anchors[:,
1] > 0.0]
per_image_reg_head = per_image_reg_head[per_image_anchors[:,
1] > 0.0]
per_image_anchors = per_image_anchors[per_image_anchors[:,
1] > 0.0]
per_image_cls_head = per_image_cls_head[
per_image_anchors[:, 2] < image_w]
per_image_reg_head = per_image_reg_head[
per_image_anchors[:, 2] < image_w]
per_image_anchors = per_image_anchors[
per_image_anchors[:, 2] < image_w]
per_image_cls_head = per_image_cls_head[
per_image_anchors[:, 3] < image_h]
per_image_reg_head = per_image_reg_head[
per_image_anchors[:, 3] < image_h]
per_image_anchors = per_image_anchors[
per_image_anchors[:, 3] < image_h]
per_image_cls_head = per_image_cls_head.unsqueeze(0)
per_image_reg_head = per_image_reg_head.unsqueeze(0)
per_image_anchors = per_image_anchors.unsqueeze(0)
final_cls_heads.append(per_image_cls_head)
final_reg_heads.append(per_image_reg_head)
final_batch_anchors.append(per_image_anchors)
final_cls_heads = torch.cat(final_cls_heads, axis=0)
final_reg_heads = torch.cat(final_reg_heads, axis=0)
final_batch_anchors = torch.cat(final_batch_anchors, axis=0)
# final cls heads shape:[batch_size, anchor_nums, class_num]
# final reg heads shape:[batch_size, anchor_nums, 4]
# final batch anchors shape:[batch_size, anchor_nums, 4]
return final_cls_heads, final_reg_heads, final_batch_anchors
def get_batch_anchors_annotations(self, batch_anchors, annotations):
"""
Assign a ground truth box target and a ground truth class target to each anchor
if anchor gt_class index = -1, this anchor is ignored and calculates neither cls loss nor reg loss
if anchor gt_class index = 0, this anchor is a background anchor and is only used in the cls loss
if anchor gt_class index > 0, this anchor is an object anchor and is used in both the cls loss and the reg loss
"""
device = annotations.device
assert batch_anchors.shape[0] == annotations.shape[0]
one_image_anchor_nums = batch_anchors.shape[1]
batch_anchors_annotations = []
for one_image_anchors, one_image_annotations in zip(
batch_anchors, annotations):
# drop all index=-1 class annotations
one_image_annotations = one_image_annotations[
one_image_annotations[:, 4] >= 0]
if one_image_annotations.shape[0] == 0:
one_image_anchor_annotations = torch.ones(
[one_image_anchor_nums, 5], device=device) * (-1)
else:
one_image_gt_bboxes = one_image_annotations[:, 0:4]
one_image_gt_class = one_image_annotations[:, 4]
one_image_ious = self.compute_ious_for_one_image(
one_image_anchors, one_image_gt_bboxes)
# for each anchor, take the annotation with the maximum IoU
overlap, indices = one_image_ious.max(axis=1)
# assign each anchor the gt bbox of its max IoU annotation
per_image_anchors_gt_bboxes = one_image_gt_bboxes[indices]
# transform gt bboxes to [tx,ty,tw,th] format for each anchor
one_image_anchors_snaped_boxes = self.snap_annotations_as_tx_ty_tw_th(
per_image_anchors_gt_bboxes, one_image_anchors)
one_image_anchors_gt_class = (torch.ones_like(overlap) *
-1).to(device)
# if max iou < 0.4, assign gt class 0 (background)
one_image_anchors_gt_class[overlap < 0.4] = 0
# if max iou >= 0.5, assign the max iou annotation's class index + 1 (80 classes, indices 1 to 80)
one_image_anchors_gt_class[
overlap >=
0.5] = one_image_gt_class[indices][overlap >= 0.5] + 1
one_image_anchors_gt_class = one_image_anchors_gt_class.unsqueeze(
-1)
one_image_anchor_annotations = torch.cat([
one_image_anchors_snaped_boxes, one_image_anchors_gt_class
],
axis=1)
one_image_anchor_annotations = one_image_anchor_annotations.unsqueeze(
0)
batch_anchors_annotations.append(one_image_anchor_annotations)
batch_anchors_annotations = torch.cat(batch_anchors_annotations,
axis=0)
# batch anchors annotations shape:[batch_size, anchor_nums, 5]
return batch_anchors_annotations
def snap_annotations_as_tx_ty_tw_th(self, anchors_gt_bboxes, anchors):
"""
snap each anchor ground truth bbox from format:[x_min,y_min,x_max,y_max] to format:[tx,ty,tw,th]
"""
anchors_w_h = anchors[:, 2:] - anchors[:, :2]
anchors_ctr = anchors[:, :2] + 0.5 * anchors_w_h
anchors_gt_bboxes_w_h = anchors_gt_bboxes[:,
2:] - anchors_gt_bboxes[:, :2]
anchors_gt_bboxes_w_h = torch.clamp(anchors_gt_bboxes_w_h, min=1.0)
anchors_gt_bboxes_ctr = anchors_gt_bboxes[:, :
2] + 0.5 * anchors_gt_bboxes_w_h
snaped_annotations_for_anchors = torch.cat(
[(anchors_gt_bboxes_ctr - anchors_ctr) / anchors_w_h,
torch.log(anchors_gt_bboxes_w_h / anchors_w_h)],
axis=1)
device = snaped_annotations_for_anchors.device
factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
snaped_annotations_for_anchors = snaped_annotations_for_anchors / factor
# snaped_annotations_for_anchors shape:[anchor_nums, 4]
return snaped_annotations_for_anchors
def compute_ious_for_one_image(self, one_image_anchors,
one_image_annotations):
"""
compute ious between one image anchors and one image annotations
"""
# make sure anchors format:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
# make sure annotations format: [annotation_nums,4],4:[x_min,y_min,x_max,y_max]
annotation_num = one_image_annotations.shape[0]
one_image_ious = []
for annotation_index in range(annotation_num):
annotation = one_image_annotations[
annotation_index:annotation_index + 1, :]
overlap_area_top_left = torch.max(one_image_anchors[:, :2],
annotation[:, :2])
overlap_area_bot_right = torch.min(one_image_anchors[:, 2:],
annotation[:, 2:])
overlap_area_sizes = torch.clamp(overlap_area_bot_right -
overlap_area_top_left,
min=0)
overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
# compute anchor and annotation widths and heights (with the +1 pixel-inclusive convention)
anchors_w_h = one_image_anchors[:,
2:] - one_image_anchors[:, :2] + 1
annotations_w_h = annotation[:, 2:] - annotation[:, :2] + 1
# compute anchors_area and annotations_area
anchors_area = anchors_w_h[:, 0] * anchors_w_h[:, 1]
annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]
# compute union_area
union_area = anchors_area + annotations_area - overlap_area
union_area = torch.clamp(union_area, min=1e-4)
# compute ious between one image anchors and one image annotations
ious = (overlap_area / union_area).unsqueeze(-1)
one_image_ious.append(ious)
one_image_ious = torch.cat(one_image_ious, axis=1)
# one image ious shape:[anchors_num,annotation_num]
return one_image_ious
With that, the RetinaNet loss is fully implemented.
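Finally, here is a minimal smoke test of the class above. The shapes, dummy anchors, and annotation values are assumptions for illustration only; in real training cls_heads, reg_heads, and batch_anchors come from the RetinaNet heads and the anchor generator:
import torch

loss_fn = RetinaLoss(image_w=640, image_h=640)
batch_size, num_classes = 2, 80
# pretend there are two FPN levels, with 6 and 4 anchors per image respectively
cls_heads = [
    torch.rand(batch_size, 6, num_classes),  # sigmoid probabilities in (0, 1)
    torch.rand(batch_size, 4, num_classes),
]
reg_heads = [
    torch.randn(batch_size, 6, 4),  # predicted tx, ty, tw, th
    torch.randn(batch_size, 4, 4),
]
# all dummy anchors lie inside the 640x640 image, so none are dropped by the border filter
batch_anchors = [
    torch.tensor([50., 50., 200., 200.]).repeat(batch_size, 6, 1),
    torch.tensor([10., 10., 120., 120.]).repeat(batch_size, 4, 1),
]
# one object per image, format [x_min, y_min, x_max, y_max, class_index]
annotations = torch.tensor([[[50., 50., 200., 200., 3.]],
                            [[50., 50., 200., 200., 7.]]])
cls_loss, reg_loss = loss_fn(cls_heads, reg_heads, batch_anchors, annotations)
print(cls_loss.item(), reg_loss.item())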