"""Demo: build a per-cell boolean mask from one channel and broadcast it.

Starting from a random (1, 7, 7, 30) tensor, mark the (row, col) cells whose
channel 4 is greater than 1, then expand that (1, 7, 7) mask back to the full
(1, 7, 7, 30) shape so it selects every channel of the marked cells at once.

The "Example output" comments were captured from one run. The input comes
from torch.randn without a fixed seed, so the values (and therefore the
True/False patterns) differ on every run.

NOTE(review): the 7x7x30 layout looks like a YOLOv1-style prediction grid
(channel 4 possibly a box confidence) -- not established by this file; confirm.
"""
import torch

tensor = torch.randn(1, 7, 7, 30)
print(tensor)
print(tensor.shape)
# Output:
# torch.Size([1, 7, 7, 30])

# Compare channel 4 of every cell against 1. Indexing a single position in the
# last dimension removes that dimension, so the mask has shape (1, 7, 7).
coor_mask = tensor[:, :, :, 4] > 1
print(coor_mask)
print(coor_mask.shape)
# Example output (random data; varies per run):
# tensor([[[ True, False, True, False, True, False, False],
# [False, False, False, False, False, False, False],
# [False, False, False, False, False, False, False],
# [False, False, False, False, False, False, False],
# [False, True, False, False, False, True, False],
# [False, True, False, False, False, True, True],
# [False, False, False, False, False, False, False]]])
# torch.Size([1, 7, 7])

# Think of the tensor (ignoring the leading batch dimension of size 1) as a
# box: 7 cells along x, 7 cells along y, and 30 values along z for each
# (x, y) cell.

# Re-append a trailing size-1 dimension: (1, 7, 7) -> (1, 7, 7, 1).
# (Equivalently, `tensor[..., 4:5] > 1` keeps that dimension from the start.)
coor_mask = coor_mask.unsqueeze(-1)
print(coor_mask)
print(coor_mask.shape)
# Output:
# torch.Size([1, 7, 7, 1])

# Broadcast the size-1 last dimension up to 30, so every channel of a marked
# cell is True. expand_as returns a view that shares memory instead of copying;
# clone() it before doing any in-place writes on it.
coor_mask = coor_mask.expand_as(tensor)
print(coor_mask)
print(coor_mask.shape)
# Example output (random data; varies per run; torch elides the middle rows
# and columns with "..."):
# tensor([[[[ True, True, True, ..., True, True, True],
# [False, False, False, ..., False, False, False],
# [ True, True, True, ..., True, True, True],
# ...,
# [ True, True, True, ..., True, True, True],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False]],
# [[False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# ...,
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False]],
# [[False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# ...,
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False]],
# ...,
# [[False, False, False, ..., False, False, False],
# [ True, True, True, ..., True, True, True],
# [False, False, False, ..., False, False, False],
# ...,
# [False, False, False, ..., False, False, False],
# [ True, True, True, ..., True, True, True],
# [False, False, False, ..., False, False, False]],
# [[False, False, False, ..., False, False, False],
# [ True, True, True, ..., True, True, True],
# [False, False, False, ..., False, False, False],
# ...,
# [False, False, False, ..., False, False, False],
# [ True, True, True, ..., True, True, True],
# [ True, True, True, ..., True, True, True]],
# [[False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# ...,
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False],
# [False, False, False, ..., False, False, False]]]])
# torch.Size([1, 7, 7, 30])