masked_fill_() - masked_fill() - v1.5.0

torch.Tensor
https://pytorch.org/docs/stable/tensors.html

  • torch.Tensor.masked_fill (Python method, in torch.Tensor)
  • torch.Tensor.masked_fill_ (Python method, in torch.Tensor)

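masked_fill_(mask, value) - the function name ends with an underscore. This is the in-place version: in PyTorch, an in-place operation changes a tensor's values directly in its existing memory, without making a copy. Such operations are called in-place operators.
masked_fill(mask, value) -> Tensor - the function name has no trailing underscore. This is the out-of-place version: in PyTorch, an out-of-place operation does not touch the original tensor's memory. It writes the result into a new copy and returns that copy.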
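One way to check the in-place/out-of-place difference yourself is to compare storage addresses with `Tensor.data_ptr()`. The sketch below is only an illustration: the all-zeros input is an arbitrary choice so that the sums are easy to read.

```python
import torch

data = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])

# Out-of-place: the result lives in newly allocated memory; data is untouched.
out = data.masked_fill(mask, 1.0)
print(out.data_ptr() == data.data_ptr())  # False
print(data.sum().item())                  # 0.0

# In-place: the values are written into data's own memory, and the method
# returns a tensor backed by that same memory.
ret = data.masked_fill_(mask, 1.0)
print(ret.data_ptr() == data.data_ptr())  # True
print(data.sum().item())                  # 3.0 (three True entries in mask)
```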

1. masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
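In other words, wherever the corresponding element of mask is True, that element of self is replaced with value. Elements where mask is False are left unchanged.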
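The broadcasting requirement means the mask does not need the same shape as the tensor. The sketch below uses an arbitrary 2×3 tensor and two masks that broadcast to that shape: a per-column mask of shape (3,) and a per-row mask of shape (2, 1). Both masks are illustrative choices.

```python
import torch

data = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Shape (3,) broadcasts across rows: columns 0 and 2 are filled in every row.
col_mask = torch.tensor([True, False, True])
by_col = data.masked_fill(col_mask, -1.0)

# Shape (2, 1) broadcasts across columns: the whole first row is filled.
row_mask = torch.tensor([[True], [False]])
by_row = data.masked_fill(row_mask, -1.0)

print(by_col)
print(by_row)
```

A mask like the (2, 1) one is a common way to blank out entire rows, such as padded positions in a batch, without building a full-size mask.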

1.1 Parameters

mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
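In practice, the BoolTensor mask is often produced by a comparison, because comparison operators on tensors return `torch.bool`. The sketch below uses made-up values. It replaces negatives with 0.0, which for this input gives the same result as `torch.relu`.

```python
import torch

data = torch.tensor([[1.5, -0.5, 2.0], [-3.0, 0.0, 4.0]])

# A comparison yields a torch.bool tensor of the same shape: a ready-made mask.
mask = data < 0
print(mask.dtype)  # torch.bool

# value is a plain Python number; negatives become 0.0.
clipped = data.masked_fill(mask, 0.0)
print(torch.equal(clipped, torch.relu(data)))  # True
```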

2. masked_fill(mask, value) -> Tensor

Out-of-place version of torch.Tensor.masked_fill_()
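To make "out-of-place version" concrete, the sketch below computes the same result in three ways. It assumes only that masked_fill behaves like "copy, then fill in place", which matches the description above. The `torch.where` variant is a different, element-wise selection technique included for comparison. It is not how masked_fill is implemented internally.

```python
import torch

torch.manual_seed(0)
data = torch.randn(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])

a = data.masked_fill(mask, 999.0)                          # out-of-place
b = data.clone().masked_fill_(mask, 999.0)                 # copy, then fill in place
c = torch.where(mask, torch.full_like(data, 999.0), data)  # element-wise selection

print(torch.equal(a, b), torch.equal(a, c))  # True True
```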

3. Examples

3.1 masked_fill(mask, value) -> Tensor

(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> data = torch.randn(2, 3)
>>> data
tensor([[ 1.1389,  0.7854, -1.1975],
        [ 0.1931,  1.4460, -0.0749]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill(mask, 999)
>>> masked1
tensor([[ 9.9900e+02,  7.8542e-01,  9.9900e+02],
        [ 1.9310e-01,  9.9900e+02, -7.4897e-02]])
>>>
>>> data
tensor([[ 1.1389,  0.7854, -1.1975],
        [ 0.1931,  1.4460, -0.0749]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$

3.2 masked_fill_(mask, value)

(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> data = torch.randn(2, 3)
>>> data
tensor([[ 0.0718, -0.4983, -0.7344],
        [-2.0372, -1.6503,  1.6308]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill_(mask, 999)
>>> masked1
tensor([[ 9.9900e+02, -4.9832e-01,  9.9900e+02],
        [-2.0372e+00,  9.9900e+02,  1.6308e+00]])
>>>
>>> data
tensor([[ 9.9900e+02, -4.9832e-01,  9.9900e+02],
        [-2.0372e+00,  9.9900e+02,  1.6308e+00]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$
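In the transcript above, masked1 and data print identically because they share the same memory. The sketch below illustrates one consequence: a later write through masked1 also appears in data. The overwritten value -1.0 is arbitrary.

```python
import torch

data = torch.randn(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])

masked1 = data.masked_fill_(mask, 999)

# masked1 and data share storage, so a write through one is visible through the other.
masked1[0, 1] = -1.0
print(data[0, 1].item())  # -1.0
```

If you need an independent result, use masked_fill (or clone first) instead.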

3.3 Filling with 0 and -np.inf

(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> data = torch.randn(2, 3)
>>> data
tensor([[ 0.3838, -0.8961,  0.4759],
        [ 0.4764, -0.2403,  0.4010]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill(mask, 0)
>>> masked1
tensor([[ 0.0000, -0.8961,  0.0000],
        [ 0.4764,  0.0000,  0.4010]])
>>>
>>> data
tensor([[ 0.3838, -0.8961,  0.4759],
        [ 0.4764, -0.2403,  0.4010]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$
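Filling with 0 can look like multiplying by the inverted mask, but the two differ when the tensor holds nan or inf, because nan * 0 and inf * 0 are both nan. The sketch below uses a hand-picked input to show the difference. The multiplication variant is included only for comparison.

```python
import torch

data = torch.tensor([[1.0, float("nan"), float("inf")]])
mask = torch.tensor([[False, True, True]])

# masked_fill overwrites masked positions outright, whatever they held.
filled = data.masked_fill(mask, 0.0)

# Multiplying by the inverted mask computes nan * 0 and inf * 0, which are both nan.
multiplied = data * (~mask).to(data.dtype)

print(filled)      # masked positions are exactly 0
print(multiplied)  # masked positions are nan
```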
(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import numpy as np
>>>
>>> data = torch.randn(2, 3)
>>> data
tensor([[5.2904e-02, 9.4895e-01, 2.6957e-01],
        [1.2166e-03, 1.2486e+00, 3.0534e+00]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill(mask, -np.inf)
>>> masked1
tensor([[      -inf, 9.4895e-01,       -inf],
        [1.2166e-03,       -inf, 3.0534e+00]])
>>>
>>> data
tensor([[5.2904e-02, 9.4895e-01, 2.6957e-01],
        [1.2166e-03, 1.2486e+00, 3.0534e+00]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$
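A common reason to fill with -np.inf is to mask scores before a softmax, as in attention layers. Because exp(-inf) is 0, masked positions receive zero probability. The sketch below uses made-up scores and `float("-inf")`, which is the same value as -np.inf but does not need NumPy.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
mask = torch.tensor([[True, False, True], [False, True, False]])

# Masked positions become -inf, so softmax assigns them probability 0.
probs = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(probs)
# row 0 keeps only position 1 -> [0, 1, 0]
# row 1 splits evenly between positions 0 and 2 -> [0.5, 0, 0.5]
```

Note that if every position in a row is masked, the softmax of that row is nan, so callers usually make sure at least one position stays unmasked.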
