YOLOX is clean and efficient; this post walks through its concrete implementation. Much of the code can be reused elsewhere, which makes it a valuable reference.
Testing is straightforward; start with demo.py.
Running it requires three arguments:
--path: path to the test image
--exp_file: the model config file to use, e.g. exps/default/yolox_m.py
--ckpt: pretrained weights, e.g. yolox_m.pth
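For reference, a typical invocation (paths are illustrative; adjust them to your checkout) looks something like:
python tools/demo.py image --path assets/dog.jpg --exp_file exps/default/yolox_m.py --ckpt yolox_m.pth --save_result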
outputs, img_info = predictor.inference(image_name)  # outputs[0]: (14, 7) = x1, y1, x2, y2, obj_conf, class_conf, class_pred
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
Inside Predictor.inference:
img = cv2.imread(img)  # img is the image path here
ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
# Scale the original image proportionally to 640x640
img, _ = self.preproc(img, None, self.test_size)  # converted to (3, 640, 640)
with torch.no_grad():
    outputs = self.model(img)  # ([1, 8400, 85]): 8400 = 80*80 + 40*40 + 20*20; 85 = 80 + 4 + 1
    outputs = postprocess(
        outputs, self.num_classes, self.confthre,
        self.nmsthre, class_agnostic=True
    )
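The preproc step is a standard letterbox transform. A minimal sketch of what it does, assuming the usual YOLOX ValTransform behavior (aspect-preserving resize, constant padding with 114, HWC to CHW):
import cv2
import numpy as np

def preproc_sketch(img, input_size, pad_value=114):
    # Aspect-preserving resize onto a padded canvas, then HWC -> CHW float32.
    padded = np.full((input_size[0], input_size[1], 3), pad_value, dtype=np.uint8)
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized = cv2.resize(
        img, (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    )
    padded[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized
    chw = padded.transpose(2, 0, 1)  # HWC -> CHW
    return np.ascontiguousarray(chw, dtype=np.float32), r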
Inside YOLOX.forward:
fpn_outs = self.backbone(x)
# Three downsampled feature maps: (128, 80, 80), (256, 40, 40), (512, 20, 20)
outputs = self.head(fpn_outs)
Inside YOLOXHead.forward, the decoupled heads run once per level:
for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
    zip(self.cls_convs, self.reg_convs, self.strides, xin)
):  # loop 3 times, running classification and regression on one feature map each
    x = self.stems[k](x)  # project the feature map to 128 channels, e.g. level 1: (1, 128, 80, 80)
    cls_x = x
    reg_x = x
    cls_feat = cls_conv(cls_x)  # decoupled head: two consecutive Conv(128, 128, 3, 1) + BN + SiLU
    cls_output = self.cls_preds[k](cls_feat)  # Conv2d(128, 20), classification
    reg_feat = reg_conv(reg_x)  # decoupled head, same as above
    reg_output = self.reg_preds[k](reg_feat)  # Conv2d(128, 4), box regression
    obj_output = self.obj_preds[k](reg_feat)  # Conv2d(128, 1), objectness
    output = torch.cat(
        [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
    )  # (1, 25, 80, 80); 25 = 4 + 1 + 20 (VOC classes here)
    outputs.append(output)  # (1, 25, 80, 80) (1, 25, 40, 40) (1, 25, 20, 20)
self.hw = [x.shape[-2:] for x in outputs]  # [torch.Size([80, 80]), torch.Size([40, 40]), torch.Size([20, 20])]
outputs = torch.cat(
    [x.flatten(start_dim=2) for x in outputs], dim=2
).permute(0, 2, 1)  # ([1, 8400, 25])
if self.decode_in_inference:  # True
    return self.decode_outputs(outputs, dtype=xin[0].type())
else:
    return outputs
def decode_outputs(self, outputs, dtype):
    grids = []
    strides = []
    for (hsize, wsize), stride in zip(self.hw, self.strides):  # 80, 40, 20, matching downsample strides [8, 16, 32]
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
        # Using the (80, 80) map as an example: two (80, 80) grids of coordinates
        grid = torch.stack((xv, yv), 2).view(1, -1, 2)  # ([1, 6400, 2])
        grids.append(grid)
        shape = grid.shape[:2]  # ([1, 6400])
        strides.append(torch.full((*shape, 1), stride))  # (1, 6400, 1)*[8] (1, 1600, 1)*[16] (1, 400, 1)*[32]
    grids = torch.cat(grids, dim=1).type(dtype)
    strides = torch.cat(strides, dim=1).type(dtype)
    outputs[..., :2] = (outputs[..., :2] + grids) * strides  # (predicted x, y + grid cell coordinate) * stride
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides  # exp(predicted w, h) * stride
    return outputs  # ([1, 8400, 85]): 8400 = 80*80 + 40*40 + 20*20; 85 = 80 + 4 + 1 (25 for the 20-class VOC head above)
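A toy check of the decode rule on a single 2x2 stride-8 level (values invented for illustration):
import torch

hw, strides = [(2, 2)], [8]
preds = torch.zeros(1, 4, 4)  # (batch, 4 grid cells, [x, y, w, h]), all-zero raw predictions
grids, svals = [], []
for (h, w), s in zip(hw, strides):
    yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)])
    grids.append(torch.stack((xv, yv), 2).view(1, -1, 2))
    svals.append(torch.full((1, h * w, 1), s))
grid = torch.cat(grids, 1).float()
s = torch.cat(svals, 1).float()
xy = (preds[..., :2] + grid) * s     # centers: (0,0), (8,0), (0,8), (8,8)
wh = torch.exp(preds[..., 2:4]) * s  # exp(0) * 8 -> every box is 8x8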
The raw predictions are then filtered by outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre, class_agnostic=True):
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    box_corner = prediction.new(prediction.shape)
    # Convert center/size boxes to corner coordinates: x1, y1, x2, y2
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]
    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):  # image_pred: (8400, 85)
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
        # class score * objectness, thresholded at conf_thre (0.3 in this run)
        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)  # (8400, 7)
        detections = detections[conf_mask]  # (93, 7) after the confidence filtering
        if class_agnostic:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                nms_thre,
            )  # NMS over scores and box positions: returns the indices of the surviving boxes
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                detections[:, 6],
                nms_thre,
            )  # not executed here (class_agnostic=True)
        detections = detections[nms_out_index]  # (14, 7)
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))
    return output
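For intuition, torchvision.ops.nms keeps the highest-scoring box and suppresses any box whose IoU with it exceeds the threshold. A small self-contained check:
import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.45)
# keep == tensor([0, 2]): box 1 overlaps box 0 with IoU ~0.68 > 0.45 and is dropped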
Back in the demo loop, each image is run through inference and visualization, and the result is optionally saved:
outputs, img_info = predictor.inference(image_name)
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
if save_result:
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    save_file_name = os.path.join(save_folder, os.path.basename(image_name))
    logger.info("Saving detection result in {}".format(save_file_name))
    cv2.imwrite(save_file_name, result_image)
ch = cv2.waitKey(0)
def visual(self, output, img_info, cls_conf=0.35):
    ratio = img_info["ratio"]  # resize ratio, e.g. 0.45
    img = img_info["raw_img"]  # (1050, 1400, 3)
    if output is None:
        return img
    output = output.cpu()
    bboxes = output[:, 0:4]
    # undo the preprocessing resize
    bboxes /= ratio
    cls = output[:, 6]
    scores = output[:, 4] * output[:, 5]
    vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
    return vis_res
def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])
        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX
        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
    return img
Training data layout: under datasets/VOCdevkit/VOC2007/ there are three folders: JPEGImages (the jpg images), Annotations (the corresponding xml labels), and ImageSets.
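As a sketch, the expected layout (the Main split files are an assumption based on the standard VOC structure):
datasets/VOCdevkit/VOC2007/
├── JPEGImages/    # *.jpg images
├── Annotations/   # one *.xml annotation per image
└── ImageSets/     # split lists, e.g. Main/train.txt, Main/val.txt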
Training enters trainer.train() from line 110 of train.py.
In yolox.py, line 30:
fpn_outs = self.backbone(x)
if self.training:
    assert targets is not None
    loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
        fpn_outs, targets, x
    )  # IoU loss, objectness loss, and classification loss
    outputs = {
        "total_loss": loss,
        "iou_loss": iou_loss,
        "l1_loss": l1_loss,
        "conf_loss": conf_loss,
        "cls_loss": cls_loss,
        "num_fg": num_fg,
    }
else:
    outputs = self.head(fpn_outs)  # inference path
return outputs
The key function is self.get_assignments, which assigns positive labels; a detailed analysis follows below,
along with its helper self.dynamic_k_matching, which dynamically picks k positive samples per gt.
class YOLOXHead(nn.Module):
    def get_losses(self, imgs, x_shifts, y_shifts, expanded_strides, labels, outputs,
                   origin_preds, dtype):
        bbox_preds = outputs[:, :, :4]  # [bs, n_anchors, 4]: ([8, 8400, 4])
        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # ([8, 8400, 1])
        cls_preds = outputs[:, :, 5:]  # ([8, 8400, 20])
        # calculate targets
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # gt count per image: [5, 6, 21, 2, 5, 2, 2, 6]
        total_num_anchors = outputs.shape[1]  # 8400
        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]; x_shifts[0]: (1, 6400), x_shifts[1]: (1, 1600), x_shifts[2]: (1, 400); [0, 1, 2, ..., 79, 0, 1, 2, ...]
        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all] ([1, 8400])
        expanded_strides = torch.cat(expanded_strides, 1)  # (1, 8400): 6400*[8, 8, ...] 1600*[16, 16, ...] 400*[32, 32, ...]
        if self.use_l1:
            origin_preds = torch.cat(origin_preds, 1)
        cls_targets = []
        reg_targets = []
        l1_targets = []
        obj_targets = []
        fg_masks = []
        num_fg = 0.0
        num_gts = 0.0
        for batch_idx in range(outputs.shape[0]):  # iterate over the batch
            num_gt = int(nlabel[batch_idx])
            num_gts += num_gt  # 5
            if num_gt == 0:
                cls_target = outputs.new_zeros((0, self.num_classes))
                reg_target = outputs.new_zeros((0, 4))
                l1_target = outputs.new_zeros((0, 4))
                obj_target = outputs.new_zeros((total_num_anchors, 1))
                fg_mask = outputs.new_zeros(total_num_anchors).bool()
            else:
                gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]  # (5, 4)
                gt_classes = labels[batch_idx, :num_gt, 0]  # (5,) one class id per gt
                bboxes_preds_per_image = bbox_preds[batch_idx]  # (8400, 4)
                # (the original wraps this call in try/except to fall back to CPU mode on a CUDA OOM)
                (
                    gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img,
                ) = self.get_assignments(
                    batch_idx, num_gt, total_num_anchors,
                    gt_bboxes_per_image, gt_classes, bboxes_preds_per_image,
                    expanded_strides, x_shifts, y_shifts, cls_preds, bbox_preds,
                    obj_preds, labels, imgs)
                # The call above assigns positive/negative samples; see Section 3.1 (self.get_assignments) for its return values.
                torch.cuda.empty_cache()
                num_fg += num_fg_img  # 34
                cls_target = F.one_hot(
                    gt_matched_classes.to(torch.int64), self.num_classes
                ) * pred_ious_this_matching.unsqueeze(-1)  # (34) --> (34, 20), one-hot scaled by the IoU score
                obj_target = fg_mask.unsqueeze(-1)  # (8400, 1): 34 entries are True
                reg_target = gt_bboxes_per_image[matched_gt_inds]  # (34, 4)
            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.to(dtype))
            fg_masks.append(fg_mask)
            if self.use_l1:  # False
                l1_targets.append(l1_target)
        cls_targets = torch.cat(cls_targets, 0)  # (385, 20)
        reg_targets = torch.cat(reg_targets, 0)  # (385, 4)
        obj_targets = torch.cat(obj_targets, 0)  # (67200, 1); 8400 * 8 = 67200
        fg_masks = torch.cat(fg_masks, 0)  # (67200)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)
        num_fg = max(num_fg, 1)
        loss_iou = (
            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
        ).sum() / num_fg
        loss_obj = (
            self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
        ).sum() / num_fg
        loss_cls = (
            self.bcewithlog_loss(
                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
            )
        ).sum() / num_fg
        if self.use_l1:
            loss_l1 = (
                self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
            ).sum() / num_fg
        else:
            loss_l1 = 0.0
        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
        return (
            loss,
            reg_weight * loss_iou,
            loss_obj,
            loss_cls,
            loss_l1,
            num_fg / max(num_gts, 1),
        )
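The cls_target construction above is worth pausing on: the one-hot class vector is scaled by the matched IoU, giving an IoU-aware soft label. A tiny illustration (values invented):
import torch
import torch.nn.functional as F

gt_matched_classes = torch.tensor([2, 0])  # classes of two matched positives
pred_ious = torch.tensor([0.8, 0.6])       # their IoUs with the matched gts
cls_target = F.one_hot(gt_matched_classes, num_classes=4).float() * pred_ious.unsqueeze(-1)
# tensor([[0.0, 0.0, 0.8, 0.0],
#         [0.6, 0.0, 0.0, 0.0]])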
Here the gt labels are distributed over the three feature maps (8400 grid points in total) and split into positive and negative samples.
def get_assignments(self, batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image,
                    gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
                    cls_preds, bbox_preds, obj_preds, labels, imgs, mode="gpu"):
    fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
        gt_bboxes_per_image, expanded_strides, x_shifts,
        y_shifts, total_num_anchors, num_gt)  # (8400): 3473 True; (5, 3473): 325 True
    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]  # ([8400, 4]) --> ([3473, 4])
    cls_preds_ = cls_preds[batch_idx][fg_mask]  # ([3473, 20])
    obj_preds_ = obj_preds[batch_idx][fg_mask]  # ([3473, 1])
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]  # 3473
    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)  # (5, 4) & (3473, 4) --> (5, 3473)
    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
        .unsqueeze(1)
        .repeat(1, num_in_boxes_anchor, 1)
    )  # (5) --> (5, 20) --> (5, 1, 20) --> (5, 3473, 20)
    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)  # (5, 3473)
    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        )  # (3473, 20) --> sigmoid --> (5, 3473, 20)
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
        ).sum(-1)  # (5, 3473, 20) & (5, 3473, 20) --> (5, 3473)
    del cls_preds_
    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + 100000.0 * (~is_in_boxes_and_center)
    )  # (5, 3473)
    (
        num_fg,
        gt_matched_classes,
        pred_ious_this_matching,
        matched_gt_inds,
    ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
    del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
    return (
        gt_matched_classes,       # (34) classes of the 34 positive samples
        fg_mask,                  # (8400) with 34 entries True
        pred_ious_this_matching,  # (34) IoUs of the 34 positive samples
        matched_gt_inds,          # (34) which gt each positive sample matches best
        num_fg,
    )
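To make the cost concrete, a toy 2-gt by 3-anchor example (numbers invented): the huge constant effectively removes anchors that are not inside both a gt box and its center region.
import torch

pair_wise_cls_loss = torch.tensor([[1.0, 2.0, 3.0], [2.0, 1.0, 2.0]])
pair_wise_ious_loss = torch.tensor([[0.2, 0.5, 1.0], [0.8, 0.1, 0.4]])
is_in_boxes_and_center = torch.tensor([[True, True, False], [False, True, True]])
cost = (
    pair_wise_cls_loss
    + 3.0 * pair_wise_ious_loss
    + 100000.0 * (~is_in_boxes_and_center)
)
# Anchors with a False mask get a cost above 1e5, so topk(largest=False) never picks them.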
An initial screening of the 8400 predicted points:
based on the offsets between each anchor center and the gt box's top-left/bottom-right corners, keep the points whose offsets are all positive (computed as b_l, b_t, b_r, b_b; c_l, c_t, c_r, c_b work the same way for the fixed center region).
def get_in_boxes_info(
        self, gt_bboxes_per_image, expanded_strides, x_shifts,
        y_shifts, total_num_anchors, num_gt):
    expanded_strides_per_image = expanded_strides[0]  # (8400)
    x_shifts_per_image = x_shifts[0] * expanded_strides_per_image  # (8400): [0, 1, ..., 79, ..., 0, 1, ..., 39, 0, 1, ..., 19] * stride
    y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
    x_centers_per_image = (
        (x_shifts_per_image + 0.5 * expanded_strides_per_image)
        .unsqueeze(0)
        .repeat(num_gt, 1)  # (5, 8400): the 8400 cell centers as absolute coordinates on the 640x640 image
    )  # [n_anchor] -> [n_gt, n_anchor]
    y_centers_per_image = (
        (y_shifts_per_image + 0.5 * expanded_strides_per_image)
        .unsqueeze(0)
        .repeat(num_gt, 1)
    )
    gt_bboxes_per_image_l = (
        (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # ([5, 8400]) x1
    gt_bboxes_per_image_r = (
        (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # ([5, 8400]) x2
    gt_bboxes_per_image_t = (
        (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # ([5, 8400]) y1
    gt_bboxes_per_image_b = (
        (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # ([5, 8400]) y2
    b_l = x_centers_per_image - gt_bboxes_per_image_l  # ([5, 8400])
    b_r = gt_bboxes_per_image_r - x_centers_per_image
    b_t = y_centers_per_image - gt_bboxes_per_image_t
    b_b = gt_bboxes_per_image_b - y_centers_per_image
    bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)  # ([5, 8400, 4]): offsets between each cell center and the gt box sides
    is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0  # ([5, 8400])
    is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
    # in fixed center
    center_radius = 2.5
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
        1, total_num_anchors  # (5, 1) -> (5, 8400)
    ) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) + center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) + center_radius * expanded_strides_per_image.unsqueeze(0)
    c_l = x_centers_per_image - gt_bboxes_per_image_l
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)  # ([5, 8400, 4])
    is_in_centers = center_deltas.min(dim=-1).values > 0.0
    is_in_centers_all = is_in_centers.sum(dim=0) > 0
    # in boxes and in centers
    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all  # (8400): 3473 True
    is_in_boxes_and_center = (
        is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]  # ([5, 3473]): 325 True
    )
    return is_in_boxes_anchor, is_in_boxes_and_center
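The "fixed center" region is a square of half-width center_radius * stride around each gt center. A hedged arithmetic sketch (gt position invented):
# For a gt centered at (100.0, 120.0) on the stride-8 map:
center_radius, stride = 2.5, 8
cx, cy = 100.0, 120.0
x_lo, x_hi = cx - center_radius * stride, cx + center_radius * stride  # 80.0, 120.0
y_lo, y_hi = cy - center_radius * stride, cy + center_radius * stride  # 100.0, 140.0
# On the stride-32 map the same gt's region widens to 160 x 160 pixels.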
Dynamically pick k samples per gt based on IoU.
For example: the 5 gts are assigned 34 samples in total, and the matched IoU scores of those 34 samples are returned (pred_ious_this_matching).
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # Dynamic K
    # ---------------------------------------------------------------
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)  # ([5, 3473])
    ious_in_boxes_matrix = pair_wise_ious
    n_candidate_k = min(10, ious_in_boxes_matrix.size(1))  # 10
    topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)  # (5, 10)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
    dynamic_ks = dynamic_ks.tolist()  # [3, 7, 9, 9, 6]
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(
            cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
        )  # the k lowest-cost entries of the (3473), e.g. pos_idx: [3236, 3235, 3237]
        matching_matrix[gt_idx][pos_idx] = 1  # each row (each gt) of the zero matrix ([5, 3473]) gets [3, 7, 9, 9, 6] ones respectively
    del topk_ious, dynamic_ks, pos_idx
    anchor_matching_gt = matching_matrix.sum(0)  # (3473)
    if (anchor_matching_gt > 1).sum() > 0:
        # an anchor claimed by several gts is given to the gt with the lowest cost
        _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
        matching_matrix[:, anchor_matching_gt > 1] *= 0
        matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
    fg_mask_inboxes = matching_matrix.sum(0) > 0  # (3473): 34 True
    num_fg = fg_mask_inboxes.sum().item()  # 34
    fg_mask[fg_mask.clone()] = fg_mask_inboxes  # update fg_mask in place: of 8400 entries, 34 stay True
    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)  # ([5, 3473]) --> ([5, 34]).argmax --> (34)
    # [4, 4, 2, 4, 4, 4, 3, 3, 3, 3, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 3, 2, 4, 3, 3, 2, 2, 1, 3, 1, 3, 1]
    gt_matched_classes = gt_classes[matched_gt_inds]
    # (34): [14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 8., 8., 8., 8., 11., 11., 11., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 8., 14., 8., 14., 8.]
    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]  # (34) scores
    return (
        num_fg,                   # 34
        gt_matched_classes,       # (34) classes of the 34 positive samples
        pred_ious_this_matching,  # (34) IoUs of the 34 positive samples
        matched_gt_inds,          # (34) which gt each positive sample matches best
    )  # note: fg_mask (8400, 34 True) is updated in place rather than returned
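A self-contained toy run of the dynamic-k rule (IoU values invented): each gt's k is the integer part of the sum of its top candidate IoUs, clamped to at least 1.
import torch

pair_wise_ious = torch.tensor([[0.9, 0.8, 0.7, 0.10, 0.05],
                               [0.3, 0.2, 0.1, 0.05, 0.00]])
n_candidate_k = min(10, pair_wise_ious.size(1))  # 5 here
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
# gt0: sum = 2.55 -> k = 2; gt1: sum = 0.65 -> int 0, clamped to k = 1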
Back in the trainer, one optimization step looks like this:
outputs = self.model(inps, targets)
loss = outputs["total_loss"]
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()  # scale the loss before backward (mixed precision)
self.scaler.step(self.optimizer)    # unscale gradients, then take the optimizer step
self.scaler.update()                # adjust the loss scale for the next iteration
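The scaler here is torch.cuda.amp.GradScaler: scale() multiplies the loss so small fp16 gradients do not underflow, step() unscales the gradients and skips the update if any turned inf/nan, and update() grows or shrinks the scale factor accordingly.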