I recently read the RepPoints paper and wanted to go through the official code to understand it better, but to my surprise I couldn't find a single walkthrough of it online, which made me curious why. The official code is implemented on the mmdetection framework, so this is also a chance to learn that along the way. Since I'm still a beginner, I've only annotated the parts I could actually follow, so if anyone is willing to share their own insights, I welcome them with open arms. Below is the head part of the official code:
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (PointGenerator, multi_apply, multiclass_nms)
from src.reppoints_generator.point_target import point_target
from mmdet.ops import DeformConv
from mmdet.models.builder import build_loss, HEADS
#from ..registry import HEADS
from mmcv.cnn import ConvModule, bias_init_with_prob
@HEADS.register_module
class RepPointsHead(nn.Module):
"""RepPoint head.
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels of the feature map.
point_feat_channels (int): Number of channels of points features.
stacked_convs (int): How many conv layers are used.
gradient_mul (float): The multiplier to gradients from
points refinement and recognition.
point_strides (Iterable): points strides.
point_base_scale (int): bbox scale for assigning labels.
loss_cls (dict): Config of classification loss.
loss_bbox_init (dict): Config of initial points loss.
loss_bbox_refine (dict): Config of points loss in refinement.
        use_grid_points (bool): If we use the bounding box representation,
            the RepPoints are represented as grid points on the bounding box.
center_init (bool): Whether to use center point assignment.
transform_method (str): The methods to transform RepPoints to bbox.
""" # noqa: W605
    # initialize the parameters
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
point_feat_channels=256,
stacked_convs=3,
num_points=9,
gradient_mul=0.1,
point_strides=[8, 16, 32, 64, 128],
point_base_scale=4,
conv_cfg=None,
norm_cfg=None,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_init=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
loss_bbox_refine=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
use_grid_points=False,
center_init=True,
transform_method='moment',
moment_mul=0.01):
super(RepPointsHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.point_feat_channels = point_feat_channels
self.stacked_convs = stacked_convs
self.num_points = num_points
self.gradient_mul = gradient_mul
self.point_base_scale = point_base_scale
self.point_strides = point_strides
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss']
self.loss_cls = build_loss(loss_cls)
self.loss_bbox_init = build_loss(loss_bbox_init)
self.loss_bbox_refine = build_loss(loss_bbox_refine)
self.use_grid_points = use_grid_points
self.center_init = center_init
self.transform_method = transform_method
if self.transform_method == 'moment':
self.moment_transfer = nn.Parameter(
data=torch.zeros(2), requires_grad=True)
self.moment_mul = moment_mul
if self.use_sigmoid_cls:
self.cls_out_channels = self.num_classes - 1
else:
self.cls_out_channels = self.num_classes
self.point_generators = [PointGenerator() for _ in self.point_strides]
# we use deformable conv to extract points features
self.dcn_kernel = int(np.sqrt(num_points))
self.dcn_pad = int((self.dcn_kernel - 1) / 2)
assert self.dcn_kernel * self.dcn_kernel == num_points, \
"The points number should be a square number."
assert self.dcn_kernel % 2 == 1, \
"The points number should be an odd square number."
dcn_base = np.arange(-self.dcn_pad,
self.dcn_pad + 1).astype(np.float64)
dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
(-1))
self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
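        # e.g. for num_points=9 (3x3 kernel) the base offsets enumerate the
        # regular grid in (y, x) order:
        # (-1,-1), (-1,0), (-1,1), (0,-1), (0,0), (0,1), (1,-1), (1,0), (1,1)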
self._init_layers()
    # build the conv layers of the head
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
        # this loops stacked_convs (3) times, building the first three conv layers of each branch
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
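        # grid mode regresses 4 box values (a center shift and log-scale w/h,
        # consumed by gen_grid_from_reg); point mode predicts one (dy, dx)
        # pair per point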
self.reppoints_cls_conv = DeformConv(self.feat_channels,
self.point_feat_channels,
self.dcn_kernel, 1, self.dcn_pad)
self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
self.cls_out_channels, 1, 1, 0)
self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
self.point_feat_channels, 3,
1, 1)
self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
pts_out_dim, 1, 1, 0)
self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
self.point_feat_channels,
self.dcn_kernel, 1,
self.dcn_pad)
self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
pts_out_dim, 1, 1, 0)
    # initialize the weights
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.reppoints_cls_conv, std=0.01)
normal_init(self.reppoints_cls_out, std=0.01, bias=bias_cls)
normal_init(self.reppoints_pts_init_conv, std=0.01)
normal_init(self.reppoints_pts_init_out, std=0.01)
normal_init(self.reppoints_pts_refine_conv, std=0.01)
normal_init(self.reppoints_pts_refine_out, std=0.01)
    # convert a RepPoints point set into a bbox
    # three methods in total: minmax; partial minmax; moment-based (mean and standard deviation)
def points2bbox(self, pts, y_first=True):
        # pts: each point set is given as 2N scalars (N points, each an x and a y coordinate)
        # y_first: the layout of pts, i.e. whether y comes before x
"""
        Convert the point set into a bounding box.
        :param pts: the input point sets (fields); each point
            set (fields) is represented as 2n scalars.
        :param y_first: if y_first=True, the point set is represented as
            [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
            represented as [x1, y1, x2, y2 ... xn, yn].
        :return: each point set is converted to a bbox [x1, y1, x2, y2].
"""
        # split pts into x coordinates and y coordinates
pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
        pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, ...]
        pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, ...]
#minmax
if self.transform_method == 'minmax':
bbox_left = pts_x.min(dim=1, keepdim=True)[0]
bbox_right = pts_x.max(dim=1, keepdim=True)[0]
bbox_up = pts_y.min(dim=1, keepdim=True)[0]
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
dim=1)
#partial_minmax
elif self.transform_method == 'partial_minmax':
            # take only the first 4 of the (9) points along dim=1
            pts_y = pts_y[:, :4, ...]  # the ellipsis '...' is shorthand for all the remaining dimensions
            pts_x = pts_x[:, :4, ...]
bbox_left = pts_x.min(dim=1, keepdim=True)[0]
bbox_right = pts_x.max(dim=1, keepdim=True)[0]
bbox_up = pts_y.min(dim=1, keepdim=True)[0]
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
dim=1)
elif self.transform_method == 'moment':
            # mean and standard deviation of the points
pts_y_mean = pts_y.mean(dim=1, keepdim=True)
pts_x_mean = pts_x.mean(dim=1, keepdim=True)
pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
            # compute the bbox width/height from the std: a learnable per-axis
            # log-scale (moment_transfer) turns the std into a half-width / half-height
moment_transfer = (self.moment_transfer * self.moment_mul) + (
self.moment_transfer.detach() * (1 - self.moment_mul))
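            # the forward value equals moment_transfer, but the gradient
            # flowing back into moment_transfer is scaled by moment_mul
            # (the same trick as gradient_mul in forward_single)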
moment_width_transfer = moment_transfer[0]
moment_height_transfer = moment_transfer[1]
half_width = pts_x_std * torch.exp(moment_width_transfer)
half_height = pts_y_std * torch.exp(moment_height_transfer)
            # essentially mean ± (scaled) standard deviation
bbox = torch.cat([
pts_x_mean - half_width, pts_y_mean - half_height,
pts_x_mean + half_width, pts_y_mean + half_height
            ], dim=1)
else:
raise NotImplementedError
return bbox
    # generate the next round of grid points from the regressed bbox
    # at first I didn't get this: shouldn't RepPoints learn the point locations directly? why go through a bbox?
    # then I saw that the paper computes the loss by first converting the RepPoints into a pseudo box and comparing it with the GT,
    # so in grid mode the next round of points has to be generated from the regressed bbox -- not sure if I understood this correctly
def gen_grid_from_reg(self, reg, previous_boxes):
"""
        Based on the previous bboxes and regression values, we compute the
regressed bboxes and generate the grids on the bboxes.
:param reg: the regression value to previous bboxes.
:param previous_boxes: previous bboxes.
:return: generate grids on the regressed bboxes.
"""
b, _, h, w = reg.shape
bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
bwh = (previous_boxes[:, 2:, ...] -
previous_boxes[:, :2, ...]).clamp(min=1e-6)
grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
reg[:, 2:, ...])
grid_wh = bwh * torch.exp(reg[:, 2:, ...])
grid_left = grid_topleft[:, [0], ...]
grid_top = grid_topleft[:, [1], ...]
grid_width = grid_wh[:, [0], ...]
grid_height = grid_wh[:, [1], ...]
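        # place a regular dcn_kernel x dcn_kernel grid of points on the
        # regressed box: linspace gives the relative positions along each side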
        interval = torch.linspace(0., 1., self.dcn_kernel).view(
            1, self.dcn_kernel, 1, 1).type_as(reg)
        grid_x = grid_left + grid_width * interval
        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
        grid_x = grid_x.view(b, -1, h, w)
        grid_y = grid_top + grid_height * interval
grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
grid_y = grid_y.view(b, -1, h, w)
grid_yx = torch.stack([grid_y, grid_x], dim=2)
grid_yx = grid_yx.view(b, -1, h, w)
regressed_bbox = torch.cat([
grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
], 1)
return grid_yx, regressed_bbox
    # the basic structure of the head (forward pass for a single feature level)
def forward_single(self, x):
dcn_base_offset = self.dcn_base_offset.type_as(x)
# If we use center_init, the initial reppoints is from center points.
# If we use bounding bbox representation, the initial reppoints is
# from regular grid placed on a pre-defined bbox.
        # presumably for the ablation: bbox (grid) representation vs. RepPoints representation
if self.use_grid_points or not self.center_init:
scale = self.point_base_scale / 2
points_init = dcn_base_offset / dcn_base_offset.max() * scale
bbox_init = x.new_tensor([-scale, -scale, scale,
scale]).view(1, 4, 1, 1)
else:
points_init = 0
cls_feat = x
pts_feat = x
        # the three 3x3 convs on the classification branch
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
        # the three 3x3 convs on the localization branch
for reg_conv in self.reg_convs:
pts_feat = reg_conv(pts_feat)
# initialize reppoints
        # a 3x3 conv followed by a 1x1 conv; the output corresponds to offset1 in the paper's figure
pts_out_init = self.reppoints_pts_init_out(
self.relu(self.reppoints_pts_init_conv(pts_feat)))
if self.use_grid_points:
pts_out_init, bbox_out_init = self.gen_grid_from_reg(
pts_out_init, bbox_init.detach())
else:
pts_out_init = pts_out_init + points_init
# refine and classify reppoints
pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
) + self.gradient_mul * pts_out_init
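        # same forward value as pts_out_init, but only a gradient_mul
        # fraction of the gradient from the refine/cls branches flows back
        # into the initial point prediction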
        # the offset1 in the figure, fed to the deformable convs
dcn_offset = pts_out_init_grad_mul - dcn_base_offset
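        # DeformConv interprets offsets relative to its regular kernel grid,
        # so the base 3x3 grid is subtracted from the (center-relative)
        # point offsets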
        # the final outputs of the head
cls_out = self.reppoints_cls_out(
self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
pts_out_refine = self.reppoints_pts_refine_out(
self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
if self.use_grid_points:
pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
pts_out_refine, bbox_out_init.detach())
        else:
            # the final addition step shown in the paper's figure
            pts_out_refine = pts_out_refine + pts_out_init.detach()
return cls_out, pts_out_init, pts_out_refine
    # use multi_apply to run forward_single on every feature map
def forward(self, feats):
return multi_apply(self.forward_single, feats)
    # generate point sets matching the size of each feature map
def get_points(self, featmap_sizes, img_metas):
"""Get points according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: points of each image, valid flags of each image
"""
num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)  # number of feature levels
# since feature map sizes of all images are the same, we only compute
# points center for one time
multi_level_points = []
for i in range(num_levels):
            # for each feature map, generate the corresponding set of points (the x coordinates first, then the y)
points = self.point_generators[i].grid_points(
featmap_sizes[i], self.point_strides[i])
            # append to the list
multi_level_points.append(points)
points_list = [[point.clone() for point in multi_level_points]
for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level grids
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
point_stride = self.point_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape']
                # compute the valid H, W of the feature map
                # 'valid' here means: when the stride does not evenly divide H or W,
                # take the smaller of the rounded-up size and the feature map size,
                # presumably so that the generated valid flags all lie inside the feature map
                valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)  # np.ceil() rounds up to the nearest integer
valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
flags = self.point_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return points_list, valid_flag_list
    # the paper mentions representing an object initially by its center point;
    # here each center is expanded into a square pseudo box of size
    # point_base_scale * stride so that MaxIoUAssigner can match it against the GT
def centers_to_bboxes(self, point_list):
"""Get bboxes according to center points. Only used in MaxIOUAssigner.
"""
bbox_list = []
for i_img, point in enumerate(point_list):
bbox = []
for i_lvl in range(len(self.point_strides)):
scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
bbox_shift = torch.Tensor([-scale, -scale, scale,
scale]).view(1, 4).type_as(point[0])
bbox_center = torch.cat(
[point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
bbox.append(bbox_center + bbox_shift)
bbox_list.append(bbox)
return bbox_list
    # convert the predicted offsets into absolute point coordinates:
    # pts = offset * stride + center, done per level and per image
def offset_to_pts(self, center_list, pred_list):
"""Change from point offset to point coordinate.
"""
pts_list = []
for i_lvl in range(len(self.point_strides)):
pts_lvl = []
for i_img in range(len(center_list)):
pts_center = center_list[i_img][i_lvl][:, :2].repeat(
1, self.num_points)
pts_shift = pred_list[i_lvl][i_img]
yx_pts_shift = pts_shift.permute(1, 2, 0).view(
-1, 2 * self.num_points)
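                # the raw predictions are (dy, dx) pairs; pick them apart and
                # reorder to (x1, y1, ..., xn, yn) before scaling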
y_pts_shift = yx_pts_shift[..., 0::2]
x_pts_shift = yx_pts_shift[..., 1::2]
xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
pts_lvl.append(pts)
pts_lvl = torch.stack(pts_lvl, 0)
pts_list.append(pts_lvl)
return pts_list
def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
label_weights, bbox_gt_init, bbox_weights_init,
bbox_gt_refine, bbox_weights_refine, stride,
num_total_samples_init, num_total_samples_refine):
# classification loss
labels = labels.reshape(-1)
        # label_weights: per-sample weights for the classification loss (0 for samples that are ignored)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
        # focal loss in the paper
loss_cls = self.loss_cls(
cls_score,
labels,
label_weights,
avg_factor=num_total_samples_refine)
# points loss
bbox_gt_init = bbox_gt_init.reshape(-1, 4)
bbox_weights_init = bbox_weights_init.reshape(-1, 4)
bbox_pred_init = self.points2bbox(
pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
bbox_pred_refine = self.points2bbox(
pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
normalize_term = self.point_base_scale * stride
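        # point_base_scale * stride is the size of the base box at this
        # level, so boxes from every FPN level are normalized to a
        # comparable scale before the SmoothL1 loss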
loss_pts_init = self.loss_bbox_init(
bbox_pred_init / normalize_term,
bbox_gt_init / normalize_term,
bbox_weights_init,
avg_factor=num_total_samples_init)
loss_pts_refine = self.loss_bbox_refine(
bbox_pred_refine / normalize_term,
bbox_gt_refine / normalize_term,
bbox_weights_refine,
avg_factor=num_total_samples_refine)
return loss_cls, loss_pts_init, loss_pts_refine
def loss(self,
cls_scores,
pts_preds_init,
pts_preds_refine,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.point_generators)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
# target for initial stage
        # preparation before the initial-stage loss:
        # 1. generate the center coordinates of the targets from the feature map sizes
        # 2. compute the RepPoints coordinates from the centers and the predicted offsets
        # 3. choose the assignment form (bbox or points); for bbox, convert the centers computed above into pseudo boxes
        # 4. fetch the GT boxes and classification labels assigned to these proposals, for the loss
        # 5. finally get the number of positive and negative proposals, and their sum (the total sample count)
center_list, valid_flag_list = self.get_points(featmap_sizes,
img_metas)
pts_coordinate_preds_init = self.offset_to_pts(center_list,
pts_preds_init)
if cfg.init.assigner['type'] == 'PointAssigner':
# Assign target for center list
candidate_list = center_list
else:
# transform center list to bbox list and
# assign target for bbox list
bbox_list = self.centers_to_bboxes(center_list)
candidate_list = bbox_list
        # get the GT box and classification label assigned to each proposal
cls_reg_targets_init = point_target(
candidate_list,
valid_flag_list,
gt_bboxes,
img_metas,
cfg.init,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
        # num_total_pos_init: number of positive samples
        # num_total_neg_init: number of negative samples
(*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
        # total samples = positives + negatives (when sampling is used)
num_total_samples_init = (
num_total_pos_init +
num_total_neg_init if self.sampling else num_total_pos_init)
# target for refinement stage
        # preparation before the refinement-stage loss:
        # basically the same five steps as above, except step 3: instead of choosing an assignment form, the initial RepPoints are converted directly into pseudo bboxes to compute the loss
center_list, valid_flag_list = self.get_points(featmap_sizes,
img_metas)
pts_coordinate_preds_refine = self.offset_to_pts(
center_list, pts_preds_refine)
bbox_list = []
for i_img, center in enumerate(center_list):
bbox = []
for i_lvl in range(len(pts_preds_refine)):
                # convert the RepPoints into a bbox to compute the loss
bbox_preds_init = self.points2bbox(
pts_preds_init[i_lvl].detach())
                # map the bbox from feature-map scale to the original image (multiply by the corresponding stride)
bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
bbox_center = torch.cat(
[center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
bbox.append(bbox_center +
bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
            # all the bboxes of each image
bbox_list.append(bbox)
        # same as for the initial stage
cls_reg_targets_refine = point_target(
bbox_list,
valid_flag_list,
gt_bboxes,
img_metas,
cfg.refine,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
(labels_list, label_weights_list, bbox_gt_list_refine,
candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
num_total_neg_refine) = cls_reg_targets_refine
num_total_samples_refine = (
num_total_pos_refine +
num_total_neg_refine if self.sampling else num_total_pos_refine)
# compute loss
        # this ultimately still goes through loss_single
losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
            self.loss_single,  # every list below is passed element-wise through loss_single
cls_scores,
pts_coordinate_preds_init,
pts_coordinate_preds_refine,
labels_list,
label_weights_list,
bbox_gt_list_init,
bbox_weights_list_init,
bbox_gt_list_refine,
bbox_weights_list_refine,
self.point_strides,
num_total_samples_init=num_total_samples_init,
num_total_samples_refine=num_total_samples_refine)
loss_dict_all = {
'loss_cls': losses_cls,
'loss_pts_init': losses_pts_init,
'loss_pts_refine': losses_pts_refine
}
return loss_dict_all
def get_bboxes(self,
cls_scores,
pts_preds_init,
pts_preds_refine,
img_metas,
cfg,
rescale=False,
nms=True):
assert len(cls_scores) == len(pts_preds_refine)
bbox_preds_refine = [
self.points2bbox(pts_pred_refine)
for pts_pred_refine in pts_preds_refine
]
num_levels = len(cls_scores)
mlvl_points = [
self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
self.point_strides[i])
for i in range(num_levels)
]
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds_refine[i][img_id].detach()
for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale, nms)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_points,
img_shape,
scale_factor,
cfg,
rescale=False,
nms=True):
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
mlvl_bboxes = []
mlvl_scores = []
for i_lvl, (cls_score, bbox_pred, points) in enumerate(
zip(cls_scores, bbox_preds, mlvl_points)):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
max_scores, _ = scores[:, 1:].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
points = points[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
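            # bbox_pred is in stride units relative to the point center;
            # scale it back to image coordinates and re-center on the point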
bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center
x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1])
y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0])
x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1])
y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0])
bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
if self.use_sigmoid_cls:
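            # multiclass_nms in mmdetection 1.x treats column 0 as the
            # background class, so prepend a zero column to the sigmoid scores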
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
if nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores
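That's the whole head. To sanity-check the shapes I sketched a tiny smoke test of the forward pass below. This is my own sketch, not part of the official repo: it assumes an mmdetection 1.x environment where the imports above resolve, and a GPU, because mmdet.ops.DeformConv has no CPU implementation.

import torch

if __name__ == '__main__':
    # 81 classes = 80 COCO classes + background (mmdetection 1.x convention);
    # with use_sigmoid the head then outputs 80 score channels
    head = RepPointsHead(num_classes=81, in_channels=256).cuda()
    head.init_weights()
    # five dummy FPN levels for a 512x512 input, strides 8/16/32/64/128
    feats = [torch.randn(2, 256, 512 // s, 512 // s).cuda()
             for s in [8, 16, 32, 64, 128]]
    cls_out, pts_init, pts_refine = head(feats)
    for c, p_init, p_refine in zip(cls_out, pts_init, pts_refine):
        # cls: (2, 80, H, W); points: (2, 2*num_points=18, H, W)
        print(c.shape, p_init.shape, p_refine.shape)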
A few more things worth sharing:
mmdetection has a frequently used helper, multi_apply(); I found a blog post that explains it well (click here).
Also, I never got this code to run. It was probably written against mmdetection 1.x, while the framework has since moved on to 2.3 (I'm not sure whether the old version can still be installed), and some functions have changed or disappeared. If anyone manages to run it, please share so we can all make progress together -- thanks!