Official documentation
Excerpt from the parameter descriptions:
Arguments:
input (Tensor[N, C, H, W]): input tensor
boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
format where the regions will be taken from. If a single Tensor is passed,
then the first column should contain the batch index. If a list of Tensors
is passed, then each Tensor will correspond to the boxes for an element i
in a batch
output_size (int or Tuple[int, int]): the size of the output after the cropping
is performed, as (height, width)
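I had never tried these ops before and simply dropped them in to replace the roi_pool / roi_align modules of an old project, the kind that had to be compiled from C. Only after printing the results did I find out how they actually behave orz.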
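It turns out that batch_id sits in column 0 of each box, and the coordinate values are absolute positions. Each box should have the format [batch_id, x1, y1, x2, y2], where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.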
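To make the two accepted box layouts concrete, here is a minimal sketch that flattens per-image box lists into [batch_id, x1, y1, x2, y2] rows. The helper name `to_roi_rows` is hypothetical and not part of torchvision; it uses plain Python lists so it runs without any dependencies.

```python
def to_roi_rows(boxes_per_image):
    """Flatten per-image (x1, y1, x2, y2) boxes into [batch_id, x1, y1, x2, y2] rows.

    `boxes_per_image[i]` holds the boxes that belong to image i of the batch,
    so the position in the outer list becomes the batch index in column 0.
    """
    rows = []
    for batch_id, boxes in enumerate(boxes_per_image):
        for x1, y1, x2, y2 in boxes:
            rows.append([float(batch_id), float(x1), float(y1), float(x2), float(y2)])
    return rows


# One box for image 0, two boxes for image 1.
rois = to_roi_rows([[(0, 0, 1, 1)], [(2, 2, 4, 4), (0, 1, 3, 2)]])
for row in rois:
    print(row)
```

Passing the result through `torch.tensor(rois)` should give the single `Tensor[K, 5]` form described in the docstring, with K equal to the total number of boxes across the batch.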
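import torch
from torchvision.ops import roi_align, roi_pool

# fp = torch.randn([1, 1, 5, 5])
# A 5x5 feature map holding 0..24, so every pooled value is easy to trace back.
fp = torch.arange(5 * 5, dtype=torch.float32)
fp = fp.view(1, 1, 5, 5)  # [N, C, H, W]
print(fp)

# One box in the format [batch_id, x1, y1, x2, y2], in absolute coordinates.
boxes = torch.tensor([[0, 0, 0, 1, 1]]).float()

# Pool the same box to a 4x4 output with both ops and compare.
pooled_features = roi_align(fp, boxes, [4, 4])
print(pooled_features)
pooled_features = roi_pool(fp, boxes, [4, 4])
print(pooled_features)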
Output:
tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])
tensor([[[[0.7500, 1.0000, 1.2500, 1.5000],
          [2.0000, 2.2500, 2.5000, 2.7500],
          [3.2500, 3.5000, 3.7500, 4.0000],
          [4.5000, 4.7500, 5.0000, 5.2500]]]])
tensor([[[[0., 0., 1., 1.],
          [0., 0., 1., 1.],
          [5., 5., 6., 6.],
          [5., 5., 6., 6.]]]])
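To see where these numbers come from, the following pure-Python sketch reproduces both 4x4 results for the box (0, 0, 1, 1). It is not the torchvision implementation. It rests on two assumptions that happen to match the printed output: roi_align takes a single bilinear sample at the centre of each bin for this box, and roi_pool rounds the box to integer pixels, treats the end coordinate as inclusive, and takes the max over each bin. The function names are hypothetical.

```python
import math


def bilinear(feat, y, x):
    """Bilinear interpolation of a 2-D list at the continuous point (y, x)."""
    h, w = len(feat), len(feat[0])
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    ly, lx = y - y0, x - x0
    top = feat[y0][x0] * (1 - lx) + feat[y0][x1] * lx
    bottom = feat[y1][x0] * (1 - lx) + feat[y1][x1] * lx
    return top * (1 - ly) + bottom * ly


def roi_align_sketch(feat, box, out_h, out_w):
    """Assumed behaviour: one bilinear sample at the centre of each bin."""
    x1, y1, x2, y2 = box
    bin_h = max(y2 - y1, 1.0) / out_h
    bin_w = max(x2 - x1, 1.0) / out_w
    return [[bilinear(feat, y1 + (i + 0.5) * bin_h, x1 + (j + 0.5) * bin_w)
             for j in range(out_w)] for i in range(out_h)]


def roi_pool_sketch(feat, box, out_h, out_w):
    """Assumed behaviour: rounded box, inclusive end, max over integer pixel bins."""
    h, w = len(feat), len(feat[0])
    x1, y1, x2, y2 = (int(round(v)) for v in box)
    bin_h = max(y2 - y1 + 1, 1) / out_h
    bin_w = max(x2 - x1 + 1, 1) / out_w
    out = []
    for i in range(out_h):
        hs = min(max(math.floor(i * bin_h) + y1, 0), h)
        he = min(max(math.ceil((i + 1) * bin_h) + y1, 0), h)
        row = []
        for j in range(out_w):
            ws = min(max(math.floor(j * bin_w) + x1, 0), w)
            we = min(max(math.ceil((j + 1) * bin_w) + x1, 0), w)
            vals = [feat[r][c] for r in range(hs, he) for c in range(ws, we)]
            row.append(max(vals) if vals else 0.0)  # empty bins yield 0
        out.append(row)
    return out


# Same 5x5 feature map as in the experiment: values 0..24 in row-major order.
feat = [[float(5 * r + c) for c in range(5)] for r in range(5)]
box = (0.0, 0.0, 1.0, 1.0)
align = roi_align_sketch(feat, box, 4, 4)
pool = roi_pool_sketch(feat, box, 4, 4)
for row in align:
    print(row)
for row in pool:
    print(row)
```

Under these assumptions, the roi_align bins are 0.25 wide and are sampled at 0.125, 0.375, 0.625 and 0.875, which gives the fractional values. The roi_pool box covers the 2x2 pixel block from (0, 0) to (1, 1) inclusive, so each pixel value is repeated across two output bins.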