PyTorch 使用 RoIAlign 的代码实例

# Build RoIAlign features for every sample in the batch and stack them into
# `all_roi_align_feats`.
#
# Per sample: decode detections from the heatmap and regression maps with
# ctdet_decode, crop the heatmap channel of the detected class with
# torchvision.ops.roi_align, group the crops by class id (0, 1, 2),
# concatenate the groups along the channel axis, and add the sample's heatmap.
#
# NOTE(review): the in-place add near the end requires the per-sample heatmap
# to be broadcastable into (N, 3, 128, 128) — presumably it is
# (1, 3, 128, 128); confirm against the model that produces `out_hmap_1`.
num_classes = 3
roi_size = (128, 128)

batch_size = x.shape[0]
all_roi_align_feats = []
for i in range(batch_size):
    # Slice out one sample but keep a leading batch dimension of size 1.
    hmap1_s = out_hmap_1[i, :, :, :].unsqueeze(0)
    regs1_s = regs[i, :, :, :].unsqueeze(0)

    dets = ctdet_decode(hmap1_s, regs1_s, K=100)

    # One list of RoI crops per class id.
    feats_per_class = [[] for _ in range(num_classes)]

    # NOTE(review): this walks dim 0 of `dets`, while the body indexes dim 1.
    # If ctdet_decode returns (batch, K, ...) for this single-sample input,
    # only detection 0 is ever used — confirm that is intended. Iterating
    # over all K detections would also require the per-class crop counts to
    # match for the dim-1 concatenation below.
    for index in range(len(dets)):
        box_list = [dets[:, index, :4]]
        # The last detection field is read as the class id.
        cls_index = int(dets[0, index, -1].item())
        roi = torchvision.ops.roi_align(
            input=hmap1_s[:, cls_index, :, :].unsqueeze(0),
            boxes=box_list,
            output_size=roi_size,
        )
        # Class ids outside 0..2 are ignored.
        if 0 <= cls_index < num_classes:
            feats_per_class[cls_index].append(roi)

    # A class without any crop contributes an all-zero placeholder so the
    # channel concatenation always has `num_classes` entries. `new_zeros`
    # allocates it with the heatmap's device and dtype, so no hard-coded
    # `.cuda()` transfer is needed and the code also runs on CPU or on a
    # non-default GPU.
    class_feats = [
        torch.cat(feats, 0) if feats else hmap1_s.new_zeros((1, 1) + roi_size)
        for feats in feats_per_class
    ]
    roi_align_feats = torch.cat(class_feats, 1)

    roi_align_feats += hmap1_s
    all_roi_align_feats.append(roi_align_feats)

all_roi_align_feats = torch.cat(all_roi_align_feats, 0)

​

你可能感兴趣的:(大数据)