Faster R-CNN中RPN网络anchor box的计算(代码演示)

详见代码:

import numpy as np
import matplotlib.pyplot as plt
import cv2
import matplotlib.patches as patches
from PIL import Image

# Size of the input image: width W and height H, in pixels.
W = 256
H = 256

# Down-sampling factor of the backbone (input pixels per feature-map cell).
rpn_stride = 8  # times downsampling

# Width and height (in cells) of the feature map produced by the backbone.
# Rounded up to an integer cell count; this is the same number of anchor
# points that np.arange(0, W / rpn_stride) would yield for a fractional size.
w = int(np.ceil(W / rpn_stride))
h = int(np.ceil(H / rpn_stride))

# Anchor scales and aspect ratios. For each (scale, ratio) pair, anchor()
# builds width = scale * sqrt(ratio) and height = scale / sqrt(ratio), so
# width * height == scale**2 (scale is the side of an equal-area square)
# and width / height == ratio.
scales = [3, 5, 9]      # sqrt of anchor area, i.e. side of equal-area square
ratios = [0.5, 1, 2]    # width / height; 3 ratios

def anchor(w, h, rpn_stride, scales, ratios):
    """Generate every RPN anchor box for a feature map, in input-image pixels.

    Anchor points sit at (col * rpn_stride, row * rpn_stride) for every
    feature-map cell, i.e. each cell's top-left corner mapped to the input
    image. At every point one box is created per (ratio, scale) pair with
    width = scale * sqrt(ratio) and height = scale / sqrt(ratio); scale and
    ratio are used as given (they are not multiplied by rpn_stride).

    Parameters:
        w, h: feature-map width and height in cells. np.arange(0, w) is used,
            so a fractional value effectively rounds up.
        rpn_stride: down-sampling factor mapping anchor points to the input.
        scales: anchor scales (sqrt of the box area).
        ratios: anchor aspect ratios (width / height).

    Returns:
        numpy.ndarray of shape (k * h * w, 4) with k = len(scales) * len(ratios).
        Each row is (y1, x1, y2, x2): top-left then bottom-right, y first.
        Rows are anchor-type major (all locations for one (ratio, scale) pair,
        then the next); anchor types are ratio-major / scale-minor; locations
        are row-major (y outer, x inner).
    """
    scale_vals = np.ravel(scales)
    ratio_vals = np.ravel(ratios)

    # Every (ratio, scale) combination: ratio changes slowest, scale fastest.
    pair_scales = np.tile(scale_vals, ratio_vals.size)
    pair_ratios = np.repeat(ratio_vals, scale_vals.size)
    root = np.sqrt(pair_ratios)
    box_w = pair_scales * root
    box_h = pair_scales / root

    # Feature-map cell corners expressed in input-image coordinates.
    grid_x = np.arange(0, w) * rpn_stride
    grid_y = np.arange(0, h) * rpn_stride

    # Lay everything out as (anchor type, row, column) and let broadcasting
    # pair each anchor size with each anchor point.
    cy = grid_y[np.newaxis, :, np.newaxis]
    cx = grid_x[np.newaxis, np.newaxis, :]
    half_h = 0.5 * box_h[:, np.newaxis, np.newaxis]
    half_w = 0.5 * box_w[:, np.newaxis, np.newaxis]

    corners = np.broadcast_arrays(cy - half_h, cx - half_w,
                                  cy + half_h, cx + half_w)
    return np.stack(corners, axis=-1).reshape(-1, 4)

anchors = anchor(w, h, rpn_stride, scales, ratios)
print(anchors.shape)

# Load the demo image as RGB (OpenCV reads BGR, matplotlib expects RGB) and
# resize it to the W x H input size the anchors were generated for.
img = cv2.imread('timg.jpg', cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; without this check the failure surfaces inside cvtColor.
    raise FileNotFoundError("could not read image 'timg.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (W, H))

plt.figure(figsize=(10, 10))
plt.imshow(img)
ax = plt.gca()

# anchor() returns rows as (y1, x1, y2, x2), while Rectangle takes the
# (x, y) corner followed by width and height, so x and y must be swapped here.
for y1, x1, y2, x2 in anchors:
    rec = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                            edgecolor='r', facecolor='none')
    ax.add_patch(rec)
plt.show()


Result: the input image with all generated anchor boxes drawn on top:
Faster R-CNN中RPN网络anchor box的计算(代码演示)_第1张图片

你可能感兴趣的:(深度学习)