pytorch每日一学49(torch.where())根据t指定条件更改指定tensor中的数值

第49个方法

torch.where(condition, x, y) → Tensor

此方法是将x中的元素和条件相比,如果符合条件就还等于原来的元素,如果不满足条件的话,那就去y中对应的值,公式为
pytorch每日一学49(torch.where())根据t指定条件更改指定tensor中的数值_第1张图片
举个例子就清楚了,官方例子如下:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)

其中第一个例子,取x大于零,那么x中小于0的值就会变成y中与之相对应位置的值,由于y里面都是1,所以x中小于0的位置都变成了1。

第二个例子条件仍然为x>0,但是这时的y是0,那么x中小于0的值就都变成了0。

其实是有对应位置关系的,x中元素不满足条件时,取得是y中对应的元素来填充里面的值。如下所示:
pytorch每日一学49(torch.where())根据t指定条件更改指定tensor中的数值_第2张图片
a中第一个小于零的值为-1,被b中对应的值3直接填补,而-2被7填补,所以填补是有对应关系的,所以x和y必须是可广播的,否则就会报错。

标量和张量也可以组合,就如同官方的第二个示例,其中x就是张量,而y就是标量。目前有效地标量和张量的组合为:

  • float型和double型
  • integral型和long型
  • complex型和complex 128型

而torch.where(condition)与torch.nonzero(condition, as tuple=True)相等,就是返回了其中满足条件的元素的位置,不过是以元组的形式返回的,第一个元组代表了第一个维度的位置,第二个代表第二维,一次类推,如下所示:

pytorch每日一学49(torch.where())根据t指定条件更改指定tensor中的数值_第3张图片
其中取了a中大于0的元素,返回了它们的位置,取第一个里面的为第一维,第二个里面的为第二维,即位置为(0, 0), (0, 3)…,详情请看torch.nonzero()

你可能感兴趣的:(pytorch每日一学,深度学习,pytorch,机器学习,数据挖掘,神经网络)