where和gather

本篇介绍两个高阶操作where和gather

where

  • torch.where(condition, x, y) -> Tensor

gather

应用场景

在总共10类的分类任务中,如果真实标签是100到109,而不是0到1。但是网络的预测值却是0到1,那我怎么将预测值和真实值映射到一起呢?

有人可能就说,我让预测值都加100,或者真实标签都减100,不就行了嘛。

但如果真实标签和预测值之间不是这种简单的映射关系呢,比如是[100, 102, 105, 106, 109, 110, 200, 400, 900, 10000]呢,当然,你也可以粗暴地找个映射关系。

但最优雅的方式还是使用这里的gather 方法

  • 这里虽然label中的值千奇百怪,但是这些值的索引还是0到10,这样就可以简单映射了。

你可能感兴趣的:(where和gather)