本节主要介绍分步训练中利用训练好的RPN网络生成proposal。
产生proposal的网络结构如下:
第二步利用训练好的rpn网络产生proposal,代码如下:
## Step 2: use the RPN trained in step 1 to generate proposals
print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
print 'Stage 1 RPN, generate proposals'
print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
# Keyword arguments handed to rpn_generate() through mp.Process below.
mp_kwargs = dict(
queue=mp_queue,
imdb_name=args.imdb_name,
rpn_model_path=str(rpn_stage1_out['model_path']),
cfg=cfg,
rpn_test_prototxt=rpn_test_prototxt)
p = mp.Process(target=rpn_generate, kwargs=mp_kwargs) # rpn_generate() produces the proposals
p.start() # start generating proposals
# rpn_generate() puts {'proposal_path': ...} on the queue; the result is read
# here before join() is called on the process.
# NOTE(review): `mp` is presumably the multiprocessing module -- confirm
# against the file's imports.
rpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']
p.join()
def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,
                 rpn_test_prototxt=None):
    """Run a trained RPN over an imdb and pickle the resulting proposals.

    The proposals are written to
    ``<output dir>/<model file name without extension>_proposals.pkl`` and
    the path of that file is put on ``queue`` as ``{'proposal_path': path}``.

    Arguments:
        queue: object with a ``put()`` method that receives the result dict.
        imdb_name: dataset name passed to get_imdb().
        rpn_model_path: path of the trained RPN weights.
        cfg: config object; its TEST.RPN_PRE_NMS_TOP_N and
            TEST.RPN_POST_NMS_TOP_N fields are overwritten here.
        rpn_test_prototxt: network definition passed to caffe.Net.
    """
    cfg.TEST.RPN_PRE_NMS_TOP_N = -1     # no pre-NMS filtering
    cfg.TEST.RPN_POST_NMS_TOP_N = 2000  # limit top boxes after NMS
    print('RPN model: {}'.format(rpn_model_path))
    print('Using config:')
    pprint.pprint(cfg)

    # NOTE(review): caffe is imported inside the function rather than at
    # module level -- presumably so it is only loaded in the process that
    # runs this function; confirm before moving it.
    import caffe
    _init_caffe(cfg)

    # NOTE: the matlab implementation computes proposals on flipped images, too.
    # We compute them on the image once and then flip the already computed
    # proposals. This might cause a minor loss in mAP (less proposal jittering).
    imdb = get_imdb(imdb_name)
    print('Loaded dataset `{:s}` for proposal generation'.format(imdb.name))

    # Load the RPN in test mode and work out where results go.
    net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST)
    out_dir = get_output_dir(imdb)
    print('Output will be saved to `{:s}`'.format(out_dir))

    # One entry per image in the imdb.
    proposals = imdb_proposals(net, imdb)

    # Name the proposal file after the model file, pickle the proposals to
    # disk, and report the file path (not the data) through the queue.
    model_file = os.path.basename(rpn_model_path)
    model_stem = os.path.splitext(model_file)[0]
    proposals_file = os.path.join(out_dir, model_stem + '_proposals.pkl')
    with open(proposals_file, 'wb') as fh:
        cPickle.dump(proposals, fh, cPickle.HIGHEST_PROTOCOL)
    print('Wrote RPN proposals to {}'.format(proposals_file))
    queue.put({'proposal_path': proposals_file})
首先设置 NMS 之前不做 top-N 筛选(RPN_PRE_NMS_TOP_N = -1),NMS 之后最多保留 2000 个 proposals(RPN_POST_NMS_TOP_N = 2000),然后初始化 caffe,再用 get_imdb() 函数得到 imdb 数据,方法和前面一样,再用 caffe.Net() 加载 RPN 网络,最后使用 imdb_proposals() 得到 proposal,下面进入这个函数:
def imdb_proposals(net, imdb):
    """Generate RPN proposals for every image in an imdb.

    Returns a list with one entry per image, in imdb index order; each entry
    is the ``boxes`` value that im_proposals() returned for that image.
    """
    timer = Timer()
    num_images = imdb.num_images
    all_boxes = [[] for _ in range(num_images)]
    for idx in range(num_images):
        image = cv2.imread(imdb.image_path_at(idx))
        # Only the forward pass is timed, not the image read.
        timer.tic()
        all_boxes[idx], scores = im_proposals(net, image)
        timer.toc()
        print('im_proposals: {:d}/{:d} {:.3f}s'.format(
            idx + 1, num_images, timer.average_time))
        # Debug visualisation, switched off; change the 0 to 1 to enable.
        if 0:
            dets = np.hstack((all_boxes[idx], scores))
            _vis_proposals(image, dets[:3, :], thresh=0.9)
            plt.show()
    return all_boxes
该函数的作用就是在所有的图片上生成 proposal,不过作者又嵌套了一个 im_proposals() 函数,即在一张图片上产生 proposals,进入 im_proposals() 函数中:
def im_proposals(net, im):
    """Generate RPN proposals on a single image.

    Arguments:
        net: network exposing ``blobs['data']``, ``blobs['im_info']`` and
            ``forward()``.
        im: image accepted by _get_image_blob().

    Returns:
        (boxes, scores): ``boxes`` is the net's 'rois' output without its
        first column, divided by the image scale so the coordinates refer to
        the original image size; ``scores`` is a copy of the net's 'scores'
        output.
    """
    data, im_info = _get_image_blob(im)

    # Resize the net's input blobs to this image before the forward pass.
    net.blobs['data'].reshape(*data.shape)
    net.blobs['im_info'].reshape(*im_info.shape)
    outputs = net.forward(
        data=data.astype(np.float32, copy=False),
        im_info=im_info.astype(np.float32, copy=False))

    # im_info is [[height, width, scale]]; undo the resize on the boxes.
    scale = im_info[0, 2]
    # NOTE(review): column 0 of 'rois' is dropped -- presumably a batch
    # index; confirm against the proposal layer.
    boxes = outputs['rois'][:, 1:].copy() / scale
    scores = outputs['scores'].copy()
    return boxes, scores
首先用_get_image_blob() 函数将图片数据转换为 caffe 的 blob 格式,进入该函数:
def _get_image_blob(im):
    """Convert an image into the network's input arrays.

    The image is converted to float32, has cfg.PIXEL_MEANS subtracted, and is
    rescaled so that its shorter side equals cfg.TEST.SCALES[0], unless that
    would make the longer side exceed cfg.TEST.MAX_SIZE, in which case the
    longer side is scaled to cfg.TEST.MAX_SIZE instead.

    Arguments:
        im (ndarray): a color image in BGR order

    Returns:
        blob (ndarray): the rescaled image packed by im_list_to_blob()
        im_info (ndarray): array of shape (1, 3) holding
            [height, width, scale] of the rescaled image
    """
    pixels = im.astype(np.float32, copy=True)
    pixels -= cfg.PIXEL_MEANS

    spatial_dims = pixels.shape[0:2]
    short_side = np.min(spatial_dims)
    long_side = np.max(spatial_dims)

    # Only a single test scale is supported.
    assert len(cfg.TEST.SCALES) == 1
    target_size = cfg.TEST.SCALES[0]

    scale = float(target_size) / float(short_side)
    # Prevent the biggest axis from being more than MAX_SIZE
    if np.round(scale * long_side) > cfg.TEST.MAX_SIZE:
        scale = float(cfg.TEST.MAX_SIZE) / float(long_side)

    resized = cv2.resize(pixels, None, None, fx=scale, fy=scale,
                         interpolation=cv2.INTER_LINEAR)
    im_info = np.hstack((resized.shape[:2], scale))[np.newaxis, :]

    # Create a blob to hold the (single) input image
    blob = im_list_to_blob([resized])
    return blob, im_info
上述代码调用 im_list_to_blob() 函数将缩放后的图像转换成 blob 格式,进入该函数:
def im_list_to_blob(ims):
    """Pack a list of images into one zero-padded network input blob.

    Assumes images are already prepared (means subtracted, BGR order, ...).

    Arguments:
        ims (list of ndarray): images of shape (height, width, 3); heights
            and widths may differ between images.

    Returns:
        ndarray: float32 array with axis order
            (batch elem, channel, height, width). Height and width are the
            maxima over all images; each image is placed at the top-left of
            its slot and the rest of the slot is zero.

    Raises:
        ValueError: if `ims` is empty.
    """
    # len() rather than truthiness so array-like inputs are handled too.
    if len(ims) == 0:
        raise ValueError('im_list_to_blob() requires at least one image')

    max_shape = np.array([im.shape for im in ims]).max(axis=0)
    blob = np.zeros((len(ims), max_shape[0], max_shape[1], 3),
                    dtype=np.float32)
    for i, im in enumerate(ims):
        blob[i, 0:im.shape[0], 0:im.shape[1], :] = im

    # Move channels (axis 3) to axis 1
    # Axis order will become: (batch elem, channel, height, width)
    channel_swap = (0, 3, 1, 2)
    return blob.transpose(channel_swap)
最终得到的 blob 格式为 (batch elem , channel , height , width),im_info 格式为[M,N,im_scale],其中 im_scale 是缩放比例,原始图片输入 faster rcnn 中进行训练时都需要先缩放成统一的规格;再回到 im_proposals() 函数中,使用 net.forward()函数进行一次前向传播,获得blobs_out为计算得到的proposal,数据格式如下:
之后再回到 imdb_proposals()函数中,最后返回得到的 imdb_boxes, 即我们从 RPN 上产生的 proposals。再回到 rpn_generate()函数中,变量rpn_proposals值如下图,测试中用到了5张图,因此是一个长为5的list,list 中每个元素的结构为(2000,4),即每张图有2000个预测框。
接着就是将生成的 proposals 序列化保存到磁盘,并把文件路径通过多进程队列传回主进程,供下一步训练使用,这个函数的使命就暂时完成了;