YOLOv1 代码中的 coor_mask.unsqueeze(-1) 与 coor_mask.expand_as(tensor)

import torch

# Demo input: a random 4-D float tensor of shape (1, 7, 7, 30), printed in
# full and then by shape. NOTE(review): presumably standing in for a YOLOv1
# prediction/target tensor laid out as (batch, grid, grid, 30) — illustrative.
tensor = torch.randn((1, 7, 7, 30))
print(tensor)
print(tensor.shape)

torch.Size([1, 7, 7, 30])

# Boolean mask over the 7x7 grid: True wherever channel 4 of a cell exceeds 1.
# Indexing the last axis with a scalar drops that axis, so the mask has shape
# (1, 7, 7).
# NOTE(review): in YOLOv1 loss code this channel is presumably the first box's
# confidence (typically compared with > 0); the threshold 1 here just yields a
# sparse demo mask from random data — confirm against the original loss code.
coor_mask = tensor[..., 4].gt(1)
print(coor_mask)
print(coor_mask.shape)
tensor([[[ True, False,  True, False,  True, False, False],
         [False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False],
         [False,  True, False, False, False,  True, False],
         [False,  True, False, False, False,  True,  True],
         [False, False, False, False, False, False, False]]])
torch.Size([1, 7, 7])

coor_mask.unsqueeze(-1)

# Picture `tensor` (ignoring the batch axis) as a cube: 7 x 7 along the x/y
# axes and 30 along z. The mask only covers the 7 x 7 face of that cube;
# appending a trailing size-1 axis (same as unsqueeze(-1)) gives it a z-depth
# of 1, i.e. shape (1, 7, 7, 1), so it can later be stretched along z.
coor_mask = coor_mask[..., None]
print(coor_mask)
print(coor_mask.shape)
torch.Size([1, 7, 7, 1])

coor_mask.expand_as(tensor)

# Stretch the size-1 last axis to match `tensor`, giving shape (1, 7, 7, 30):
# all 30 entries of a cell repeat that cell's mask value, so the mask can be
# used to select whole cells of `tensor`. expand(other.size()) is exactly what
# expand_as(other) does; it returns a view instead of copying the data.
coor_mask = coor_mask.expand(tensor.size())
print(coor_mask)
print(coor_mask.shape)
tensor([[[[ True,  True,  True,  ...,  True,  True,  True],
          [False, False, False,  ..., False, False, False],
          [ True,  True,  True,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ...,  True,  True,  True],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         ...,

         [[False, False, False,  ..., False, False, False],
          [ True,  True,  True,  ...,  True,  True,  True],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [ True,  True,  True,  ...,  True,  True,  True],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [ True,  True,  True,  ...,  True,  True,  True],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [ True,  True,  True,  ...,  True,  True,  True],
          [ True,  True,  True,  ...,  True,  True,  True]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]]]])
torch.Size([1, 7, 7, 30])

你可能感兴趣的:(目标检测,YOLO,python,pytorch)