




1. Tensor.masked_fill_(mask, value)

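Fills elements of the self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The trailing underscore marks the in-place version. Tensor.masked_fill(mask, value), which the example uses, returns a new tensor instead and leaves self unchanged.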

mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with


import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(3,3)
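a.masked_fill(mask, 0)  # out-of-place: returns a new tensor, a itself is unchanged
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])
a.masked_fill(~mask, 0)  # ~mask inverts the mask, so the off-diagonal is filled instead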
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])
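The random example does not show two points. First, a mask with fewer dimensions broadcasts against the tensor. Second, masked_fill_ modifies the tensor in place, while masked_fill does not. The following is a minimal sketch of both. It uses an arange-based input so the result can be checked by hand. That input and the variable names (col_mask, out) are illustrative choices and do not come from the original post.

```python
import torch

# Deterministic input holding 0..5 so the result is easy to verify
a = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# A 1-D mask of shape (3,) broadcasts across both rows of the (2, 3) tensor
col_mask = torch.tensor([True, False, True])

# Out-of-place: returns a new tensor and leaves a unchanged
out = a.masked_fill(col_mask, -1.0)
# out -> [[-1, 1, -1],
#         [-1, 4, -1]]
print(a[0, 0].item())  # 0.0, a has not been modified yet

# In-place (trailing underscore): a itself is overwritten
a.masked_fill_(col_mask, -1.0)
print(torch.equal(a, out))  # True
```

For the in-place call, the mask must broadcast to the shape of a. An in-place operation cannot enlarge the tensor it writes into.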

2. torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor containing the elements of input at the positions where mask (a BoolTensor) is True.
The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable.

Note: the returned tensor does not share storage with the original tensor. It is a copy.

input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with


import torch
x = torch.randn(3,4)
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
mask = x > 0.5
# tensor([[False, False, False, False],
#         [False, False,  True, False],
#         [False,  True, False,  True]])
torch.masked_select(x, mask)
# tensor([3.0989, 1.9527, 0.8310])
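The following sketch illustrates three details from the description:

- With a same-shape mask, the result equals boolean indexing x[mask].
- A smaller mask broadcasts against the input.
- The returned tensor is a copy.

It uses deterministic arange data in place of torch.randn so the values can be checked by hand. The names sel, col_mask, and cols are hypothetical.

```python
import torch

# Deterministic 3x4 input holding the values 0..11
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Same-shape mask: the result matches boolean indexing x[mask]
mask = x > 6.5
sel = torch.masked_select(x, mask)   # values 7..11, in row-major order
print(torch.equal(sel, x[mask]))     # True

# Mask of shape (4,) broadcasts over the rows: keeps columns 0 and 3 of each row
col_mask = torch.tensor([True, False, False, True])
cols = torch.masked_select(x, col_mask)  # values 0, 3, 4, 7, 8, 11

# The result is a copy, not a view: writing to it leaves x untouched
sel[0] = 100.0
print(x[1, 3].item())  # 7.0
```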

3. Tensor.masked_scatter_(mask, source)

Copies elements from source into the self tensor at positions where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. source must have at least as many elements as there are True values in mask. Its elements are consumed in order (row-major), not taken from the matching positions.


mask (BoolTensor) – the boolean mask
source (Tensor) – the tensor to copy from


import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.bool)
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
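a = torch.randn(2, 3, 3)  # the (3, 3) mask broadcasts over the leading dimension of size 2
s = torch.ones_like(a)
a.masked_scatter(mask, s)  # out-of-place; masked_scatter_ would modify a in place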
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],

#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])
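Because the example above uses a source of all ones, it hides one detail. masked_scatter does not copy source values from the matching positions. Instead, it consumes source elements in order. The sketch below uses illustrative tensors and variable names to show this. It then contrasts the result with a positional copy built by calling masked_select first.

```python
import torch

a = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
source = torch.arange(1, 7, dtype=torch.float32).reshape(2, 3)  # [[1, 2, 3], [4, 5, 6]]

# Source elements are consumed in order (1, 2, 3), not read from matching positions
out = a.masked_scatter(mask, source)
# out -> [[1, 0, 2],
#         [0, 3, 0]]

# For a position-for-position copy, select the values first with masked_select
picked = torch.masked_select(source, mask)  # [1, 3, 5]
positional = a.masked_scatter(mask, picked)
# positional -> [[1, 0, 3],
#                [0, 5, 0]]
```

source only needs at least as many elements as there are True entries in mask. That is why the 1-D picked tensor works even though its shape differs from that of a.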
