高阶操作(where和gather函数)

1、where

torch.where(condition, x, y) → Tensor
三个参数,condition是选择条件,x是第一个源头1,y是源头2

原理:当满足condition条件时,输出tensor对应位置为x中该位置元素,反之为y中该位置元素。
高阶操作(where和gather函数)_第1张图片

cond = torch.tensor([[0.6, 0.7],[0.8, 0.2]])

a = torch.zeros(2,2)
b = torch.ones(2,2)

torch.where(cond>0.5, a, b)

高阶操作(where和gather函数)_第2张图片

2、gather

torch.gather(input ,dim, index)
其中,input是储存待取元素的tensor,
	 dim是操作的维度, 
  	 index里各元素是索引值,从input中取每个索引值对应的元素

实例理解:
神经网络作手写数字分类(这里数字设为100-109),模型输出结果shape:(4,10),其中10表示每个目标对应十个数字分别的预测概率大小。想得到最终的结果是,对每个目标,预测最可能的三个数字结果,故最终shape为(4,3)

prob = torch.rand(4, 10)

idx = prob.topk(3, dim=1)
idx

idx = idx[1]
idx

label = torch.arange(10)+100

torch.gather(label.expand(4, 10), dim=1, index=idx.long())
# .long()函数作用是向下取整,得到的不是float格式。在这里用不用效果都一样

高阶操作(where和gather函数)_第3张图片
高阶操作(where和gather函数)_第4张图片
最终的结果含义是:比如第一行,对第一个目标预测,该目标最可能的数字是106、102和107中的某一个。

你可能感兴趣的:(pytorch,人工智能,python)