Common mask operations in PyTorch

Contents

  • Preface
  • 1.Tensor.masked_fill_(mask, value)
    • Example
  • 2.torch.masked_select(input, mask, *, out=None) → Tensor
    • Example
  • 3.Tensor.masked_scatter_(mask, source)
    • Example


Preface

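Masking is a common operation in deep learning. While studying the PyTorch code for the Transformer recently, I kept running into various mask-related calls, so I am summarizing them here.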

1.Tensor.masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.

Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
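
The trailing underscore matters: masked_fill_ modifies the tensor in place, while masked_fill returns a new tensor and leaves the original alone. Here is a minimal sketch of both, which also shows a mask of a smaller shape being broadcast. The tensor values and the name col_mask are made up for illustration.

```python
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
# A 1-D mask of shape (3,) is broadcast across both rows of the (2, 3) tensor.
col_mask = torch.tensor([True, False, True])

b = a.masked_fill(col_mask, -1.0)   # out-of-place: returns a new tensor
assert a[0, 0].item() == 0.0        # a is still unchanged at this point

a.masked_fill_(col_mask, -1.0)      # in-place: modifies a itself
assert torch.equal(a, b)
print(a.tolist())  # [[-1.0, 1.0, -1.0], [-1.0, 4.0, -1.0]]
```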

Example

import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(3, 3)
a.masked_fill(mask, 0)  # out-of-place variant: returns a new tensor, a itself is not modified
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])
a.masked_fill(~mask, 0)  # the mask can be inverted with ~
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])
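
In Transformer code, masked_fill typically shows up right before a softmax: positions that should be hidden are filled with -inf so that they receive zero attention weight. The following is a rough sketch of a causal (look-ahead) mask under assumed values. The names seq_len, scores and causal_mask are illustrative and the scores are all zeros, so this is not taken from any particular implementation.

```python
import torch

seq_len = 4
scores = torch.zeros(seq_len, seq_len)  # placeholder for raw attention scores (assumed values)

# True above the main diagonal marks "future" positions that must be hidden.
causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()

masked_scores = scores.masked_fill(causal_mask, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

# Row 1 may only attend to positions 0 and 1.
print(weights[1].tolist())  # [0.5, 0.5, 0.0, 0.0]
```

Each row of weights still sums to 1, because the softmax only redistributes probability among the positions that were not masked.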

2.torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

Note: The returned tensor does not use the same storage as the original tensor.

Parameters
input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with
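
Two points from the description above can be checked directly: the mask may have a smaller, broadcastable shape, and the result is a copy rather than a view. This is a small sketch with made-up values; col_mask and picked are illustrative names.

```python
import torch

x = torch.arange(6).reshape(2, 3)             # [[0, 1, 2], [3, 4, 5]]
col_mask = torch.tensor([True, False, True])  # shape (3,), broadcast to (2, 3)

picked = torch.masked_select(x, col_mask)
print(picked.tolist())   # [0, 2, 3, 5]

picked[0] = 100          # writing to the result...
print(x[0, 0].item())    # ...leaves x untouched: 0
```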

Example

import torch
x = torch.randn(3,4)
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
mask = x > 0.5
# tensor([[False, False, False, False],
#         [False, False,  True, False],
#         [False,  True, False,  True]])
torch.masked_select(x, mask)
# tensor([3.0989, 1.9527, 0.8310])
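
When the mask has the same shape as the input, boolean indexing with x[mask] gives the same 1-D result, and the number of returned elements equals the number of True entries. Here is a quick check with hand-picked values, chosen so that they print exactly:

```python
import torch

x = torch.tensor([[0.25, -0.5, 0.75],
                  [1.5, 0.5, -2.0]])
mask = x > 0.5

selected = torch.masked_select(x, mask)
indexed = x[mask]                            # boolean indexing with a same-shape mask

assert torch.equal(selected, indexed)
assert selected.numel() == int(mask.sum())   # one output element per True entry
print(selected.tolist())  # [0.75, 1.5]
```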

3.Tensor.masked_scatter_(mask, source)

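Copies elements from source into self tensor at positions where the mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The source should have at least as many elements as the number of ones in mask.

In other words, source does not need to have the same shape as the tensor. It only needs at least as many elements as there are True entries in the mask once the mask has been broadcast to the tensor's shape.
Elements are read from source in order, starting from its first element, and written one by one into the positions of the tensor where the mask is True. They are not taken from the positions in source that line up with the mask.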

Parameters
mask (BoolTensor) – the boolean mask
source (Tensor) – the tensor to copy from
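
The order in which source is consumed is easy to get wrong, so here is a small sketch with distinct values that makes it visible. The tensors target, mask and source are made-up examples.

```python
import torch

target = torch.zeros(2, 3, dtype=torch.long)
mask = torch.tensor([[False, True, False],
                     [True, False, True]])
source = torch.tensor([[10, 20, 30],
                       [40, 50, 60]])

# Out-of-place variant; source is read from its first element onward.
out = target.masked_scatter(mask, source)
print(out.tolist())  # [[0, 10, 0], [20, 0, 30]]
```

The three True positions receive 10, 20 and 30, the first three elements of source in row-major order. The value at position (0, 1) is 10, not source[0, 1], which is 20.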

Example

import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(2, 3, 3)
s = torch.ones_like(a)
a.masked_scatter(mask, s)  # the (3, 3) mask is broadcast over the leading dimension of a
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],
#
#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])
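
Because masked_select reads elements in the same order that masked_scatter writes them, the two can be combined to pull out a subset, transform it, and put it back. Here is a hedged sketch with made-up values. Multiplying the negatives by 10 is an arbitrary transformation chosen for illustration.

```python
import torch

x = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0]])
mask = x < 0

negatives = torch.masked_select(x, mask)          # values -2, -4, -6 in row-major order
result = x.masked_scatter(mask, negatives * 10)   # written back in that same order
print(result.tolist())  # [[1.0, -20.0, 3.0], [-40.0, 5.0, -60.0]]
```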
