【动手学深度学习】之锚框生成函数代码实现

【代码】:

import torch
from d2l import torch as d2l

# 更改打印设置
torch.set_printoptions(2)

def multibox_prior(data, sizes, ratios):
    """生成以每个像素为中心具有不同形状的锚框。"""
    # 图片的高和宽
    in_height, in_width = data.shape[-2:]
    # 查看操作设备(CPU/GPU),与锚框的参数
    device, num_sizes, num_ratios = data.device, len(sizes),len(ratios)
    # 每个像素生成的锚框数量
    boxes_per_pixel = (num_sizes +num_ratios -1)
    # 将锚框参数对应到操作设备(CPU/GPU)
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios,device=device)

    # 将锚点移动到像素中心,设置偏移量
    # 因为一个像素的高位1,宽为1,我们选择偏移中心0.5
    offset_h, offset_w = 0.5,0.5
    # 在y轴上缩放步长
    steps_h = 1.0/in_height
    # 在x轴上缩放步长
    steps_w = 1.0/in_width

    # 每个像素的中心
    center_h = (torch.arange(in_height,device=device) + offset_h)
    center_w = (torch.arange(in_width,device=device) +offset_w)
    shift_y, shift_x = torch.meshgrid(center_h, center_w)
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
    
    # 成成“boxes_per_pixel”个高和宽
    # 之后用于创建锚框的四角坐标(xmin,xmax,ymin,ymax)
    # w = s*√r*h/w
    # h = s*√r
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                   sizes[0] * torch.sqrt(ratio_tensor[1:])))\
                       * in_height / in_width  # 处理矩形输入
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))
    
    # 除以2得到半高和半宽
    # 得到的anchor_manipulations为(xmin,xmax,ymin,ymax)
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
                                        in_height * in_width, 1) / 2
    
    # 每个中心点都有“boxes_per_pixel”个锚框
    # 生成含所有锚框中心的网络,重复“boxes_per_pixel”次
    # 得到的out_grid为每个像素点中心
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                dim=1).repeat_interleave(boxes_per_pixel, dim=0)

    # 将每个像素点out_grid加上锚框anchor_manipulations得到输出
    output = out_grid + anchor_manipulations
    
    # 将output扩充一维,表示一个批量增加的锚框
    return output.unsqueeze(0)

# 导入图片
img = d2l.plt.imread('catdog.jpg')
# 得到图片高和宽
h, w = img.shape[:2]
# 假设batch_size=1,通道数为3
X = torch.rand(size=(1, 3, h, w))
# 令sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5]
Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])
# 得到Y的尺寸:【批量,锚框数量,4】
print(Y.shape)

【结果】:

 

(自学李沐老师《动手学深度学习》使用,仅供参考,侵权删除)

你可能感兴趣的:(动手学深度学习,深度学习,pytorch,深度学习,pytorch,人工智能,计算机视觉,神经网络)