【torchvision】roi_align、roi_pool使用说明

【torchvision】roi_align、roi_pool

官网介绍

部分参数介绍内容:

Arguments:
        input (Tensor[N, C, H, W]): input tensor
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from. If a single Tensor is passed,
            then the first column should contain the batch index. If a list of Tensors
            is passed, then each Tensor will correspond to the boxes for an element i
            in a batch
        output_size (int or Tuple[int, int]): the size of the output after the cropping
            is performed, as (height, width)

之前从来没有尝试过就直接替换了老项目那种需要c编译的roi_pool/roi_align模块。打出来一看才知道之这样orz。

原来batch_id在第0维…且数值为绝对位置。其格式应为[batch_id, x1, y1, x2, y2],其中(x1, y1)为左上角,(x2, y2)为右下角。

from torchvision.ops import nms, roi_align, roi_pool
import torch

# fp = torch.randn([1, 1, 5, 5])
fp = torch.tensor(list(range(5 * 5))).float()
fp = fp.view(1, 1, 5, 5)
print(fp)
# [batch_id, x1, y1, x2, y2]
boxes = torch.tensor([[0, 0, 0, 1, 1]]).float()

pooled_features = roi_align(fp, boxes, [4, 4])
print(pooled_features)

pooled_features = roi_pool(fp, boxes, [4, 4])
print(pooled_features)

tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])
tensor([[[[0.7500, 1.0000, 1.2500, 1.5000],
          [2.0000, 2.2500, 2.5000, 2.7500],
          [3.2500, 3.5000, 3.7500, 4.0000],
          [4.5000, 4.7500, 5.0000, 5.2500]]]])
tensor([[[[0., 0., 1., 1.],
          [0., 0., 1., 1.],
          [5., 5., 6., 6.],
          [5., 5., 6., 6.]]]])

你可能感兴趣的:(深度学习,pytorch,roi_align,torchvision,roi_pool,深度学习)