在R-CNN以及之后的系列文章中,都有Bounding-box Regression的使用,甚至到了MV3D等等的3D Bounding-box Regression, 其思想都是来源于最基础的Bb Regression的。我将从以下几个角度主要结合自己的理解来谈一谈Bounding-box Regression. 首先,讲一下bounding-box regression使用的动机及其解决的问题,然后分析以下这种线性回归实现的基本原理与步骤,然后分析以下R-CNN原文中此部分的代码。接着考量一下如何将其应用于3D的回归。
在我们进行了AlexNet提取特征之后,我们分析完了某个Bounding-box,得出其与Ground Truth之间的IoU大于某个阈值(如文中说得0.5),但是这个框又不是那么准确,因此我们需要对这个框进行微调,来提升我们bounding-box的IoU。
这种微调,就是通过Bounding-box Regression来进行实现的。
那么,这一微调的对象以及得到的效果具体是怎样的呢?直观而言,就是下面的这幅图表示的:
其中P是我们的原始proposal,G是Ground Truth,G拔是找到映射之后矫正得到的窗口。因此更加直接地说,我们就是要找到一种映射关系,来圈出一个G拔。
我们想要找到这种映射关系,需要找到对应的变换关系:
其中dx/y(P) 就是我们需要学习的变换规则,其判定的标准是后面将会提到了最小化损失函数。
此外,值得说明的是,如果本身P和G差距较大,也就是IoU较小,那么这种变换关系实际上是一种非线性变换,而实际上我们的这种模型是一种线性回归的模型,因此不适用于这种情况。因此,也就像上述所提到的,当IoU大于0.6时我们才执行Bounding-box Regression。
明确了目标,那么接下来我们考虑一下输入输出:
输入: Region Proposal –> P = (Px, Py, Pw, Ph)。 这是输入的表征形式,实际上真正的输入是ConvNet5之后得到的特征向量。此外还有Ground Truth t = (tx, ty, tw, th)。
输出: 我们的输出就是平移变换以及尺度缩放dx(P), dy(P), dw(P), dh(P)这四个值,然后以此得到的G拔的坐标。
再则我们来看一下损失函数的定义:
其中t*中 *表示的是:
其中波塞P表示的是proposal的特征向量,目标函数是
最后函数引入L2正则化之后的优化目标是:
通过梯度下降法或者最小二乘法就可以得到w*。
根据上述学到的参数w*, 对于测试图像,我们首先经过CNN提取特征波塞P, 然后预测变化为,最后根据(1)–(4)对窗口进行回归。
From: Github-RCNN-rcnn_predict_bbox_regressor.m
function pred_boxes = ...
rcnn_predict_bbox_regressor(model, feat, ex_boxes)
% pred_boxes = rcnn_predict_bbox_regressor(model, feat, ex_boxes)
% Predicts a new bounding box from CNN features computed on input
% bounding boxes.
%
% Inputs
% model Bounding box regressor from rcnn_train_bbox_regressor.m
% feat Input feature vectors
% ex_boxes Input bounding boxes
%
% Outputs
% pred_boxes Modified (hopefully better) ex_boxes
% AUTORIGHTS
% ---------------------------------------------------------
% Copyright (c) 2014, Ross Girshick
%
% This file is part of the R-CNN code and is available
% under the terms of the Simplified BSD License provided in
% LICENSE. Please retain this notice and LICENSE if you use
% this file (or any portion of it) in your project.
% ---------------------------------------------------------
if isempty(ex_boxes)
pred_boxes = [];
return;
end
% Predict regression targets
Y = bsxfun(@plus, feat*model.Beta(1:end-1, :), model.Beta(end, :));
% Invert whitening transformation
Y = bsxfun(@plus, Y*model.T_inv, model.mu);
% Read out predictions
dst_ctr_x = Y(:,1);
dst_ctr_y = Y(:,2);
dst_scl_x = Y(:,3);
dst_scl_y = Y(:,4);
src_w = ex_boxes(:,3) - ex_boxes(:,1) + eps;
src_h = ex_boxes(:,4) - ex_boxes(:,2) + eps;
src_ctr_x = ex_boxes(:,1) + 0.5*src_w;
src_ctr_y = ex_boxes(:,2) + 0.5*src_h;
pred_ctr_x = (dst_ctr_x .* src_w) + src_ctr_x;
pred_ctr_y = (dst_ctr_y .* src_h) + src_ctr_y;
pred_w = exp(dst_scl_x) .* src_w;
pred_h = exp(dst_scl_y) .* src_h;
pred_boxes = [pred_ctr_x - 0.5*pred_w, pred_ctr_y - 0.5*pred_h, ...
pred_ctr_x + 0.5*pred_w, pred_ctr_y + 0.5*pred_h];
From Github-RCNN-rcnn_test_bbox_regressor.m
function res = rcnn_test_bbox_regressor(imdb, rcnn_model, bbox_reg, suffix)
% AUTORIGHTS
% ---------------------------------------------------------
% Copyright (c) 2014, Ross Girshick
%
% This file is part of the R-CNN code and is available
% under the terms of the Simplified BSD License provided in
% LICENSE. Please retain this notice and LICENSE if you use
% this file (or any portion of it) in your project.
% ---------------------------------------------------------
conf = rcnn_config('sub_dir', imdb.name);
image_ids = imdb.image_ids;
% assume they are all the same
feat_opts = bbox_reg.training_opts;
num_classes = length(rcnn_model.classes);
if ~exist('suffix', 'var') || isempty(suffix)
suffix = '_bbox_reg';
else
if suffix(1) ~= '_'
suffix = ['_' suffix];
end
end
try
aboxes = cell(num_classes, 1);
for i = 1:num_classes
load([conf.cache_dir rcnn_model.classes{i} '_boxes_' imdb.name suffix]);
aboxes{i} = boxes;
end
catch
aboxes = cell(num_classes, 1);
box_inds = cell(num_classes, 1);
for i = 1:num_classes
load([conf.cache_dir rcnn_model.classes{i} '_boxes_' imdb.name]);
aboxes{i} = boxes;
box_inds{i} = inds;
clear boxes inds;
end
for i = 1:length(image_ids)
fprintf('%s: bbox reg test (%s) %d/%d\n', procid(), imdb.name, i, length(image_ids));
d = rcnn_load_cached_pool5_features(feat_opts.cache_name, ...
imdb.name, image_ids{i});
if isempty(d.feat)
continue;
end
d.feat = rcnn_pool5_to_fcX(d.feat, feat_opts.layer, rcnn_model);
d.feat = rcnn_scale_features(d.feat, feat_opts.feat_norm_mean);
if feat_opts.binarize
d.feat = single(d.feat > 0);
end
for j = 1:num_classes
I = box_inds{j}{i};
boxes = aboxes{j}{i};
if ~isempty(boxes)
scores = boxes(:,end);
boxes = boxes(:,1:4);
assert(sum(sum(abs(d.boxes(I,:) - boxes))) == 0);
boxes = rcnn_predict_bbox_regressor(bbox_reg.models{j}, d.feat(I,:), boxes);
boxes(:,1) = max(boxes(:,1), 1);
boxes(:,2) = max(boxes(:,2), 1);
boxes(:,3) = min(boxes(:,3), imdb.sizes(i,2));
boxes(:,4) = min(boxes(:,4), imdb.sizes(i,1));
aboxes{j}{i} = cat(2, single(boxes), single(scores));
if 0
% debugging visualizations
im = imread(imdb.image_at(i));
keep = nms(aboxes{j}{i}, 0.3);
for k = 1:min(10, length(keep))
if aboxes{j}{i}(keep(k),end) > -0.9
showboxes(im, aboxes{j}{i}(keep(k),1:4));
title(sprintf('%s %d score: %.3f\n', rcnn_model.classes{j}, ...
k, aboxes{j}{i}(keep(k),end)));
pause;
end
end
end
end
end
end
for i = 1:num_classes
save_file = [conf.cache_dir rcnn_model.classes{i} '_boxes_' imdb.name suffix];
boxes = aboxes{i};
inds = box_inds{i};
save(save_file, 'boxes', 'inds');
clear boxes inds;
end
end
% ------------------------------------------------------------------------
% Peform AP evaluation
% ------------------------------------------------------------------------
for model_ind = 1:num_classes
cls = rcnn_model.classes{model_ind};
try
ld = load([conf.cache_dir cls '_pr_' imdb.name suffix]);
fprintf('!!! %s : %.4f %.4f\n', cls, ld.res.ap, ld.res.ap_auc);
res(model_ind) = ld.res;
catch
res(model_ind) = imdb.eval_func(cls, aboxes{model_ind}, imdb, suffix);
end
end
fprintf('\n~~~~~~~~~~~~~~~~~~~~\n');
fprintf('Results (bbox reg):\n');
aps = [res(:).ap]';
disp(aps);
disp(mean(aps));
fprintf('~~~~~~~~~~~~~~~~~~~~\n');
3D bounding-box Regression相对于2D最大的区别就是引入了额外的z坐标以及h坐标,其他的定义方式与上述内容完全一致。
参考资料:
1. rcnn_predict_bbox_regressor.m
2. rcnn_test_bbox_regressor.m
3. http://caffecn.cn/?/question/160