def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
torch.where(condition , x, y)
某元素满足条件使用 x Tensor 来填充,不满足条件使用 y Tensor 来填充,其中 x 和 y 应当与原 Tensor 维度及 size 相同
cond = torch.rand(2, 2)
print(cond)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
c = torch.where(cond>0.5, a, b)
print(c)
输出:
tensor([[0.4716, 0.8124], [0.3771, 0.1771]]) tensor([[1., 0.], [1., 1.]])
def gather(input: Tensor, dim, index: Tensor) -> Tensor: ...
例子
prob = torch.randn(4, 10)
idx = prob.topk(dim=1, k=3) # 选出最有可能的 3 种
print(idx)
'''
values=tensor([[1.2655, 0.5347, 0.4686],
[1.9430, 1.1472, 1.1349],
[1.2370, 0.8487, 0.7665],
[2.0423, 2.0380, 1.0663]]),
indices=tensor([[8, 6, 7],
[8, 9, 5],
[2, 5, 7],
[5, 2, 3]]))
'''
idx = idx[1] # 每一个照片最有可能的 3 种情况
print(idx)
'''
tensor([[8, 6, 7],
[8, 9, 5],
[2, 5, 7],
[5, 2, 3]])
'''
label = torch.arange(10) * 100
print(label)
'''
tensor([ 0, 100, 200, 300, 400, 500, 600, 700, 800, 900])
'''
ret = torch.gather(label.expand(4, 10), dim=1, index=idx.long())
print(ret)
'''
tensor([[800, 600, 700],
[800, 900, 500],
[200, 500, 700],
[500, 200, 300]])
'''
图解