torch.Tensor.scatter() 类似 gather 的反向操作(gather是读出数据,scatter是写入数据),所以这里只解析torch.gather()。
gather()这个操作在功能上较为反人类,即使某段时间理解透彻了,过了几个月不碰可能又会变得生疏。官方文档对其描述也是较为简单,有些小伙伴看完可能还是不完全理解,本文从根本上去解析这个操作的功能。
概括地说,gather()是index_select()的延伸操作,比index_select()更加灵活,它的操作不属于块操作,而是元素级别的操作,所以性能上应该较低,我们应该尽可能地避免使用这个操作。
下面开始解析这个操作。
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
这个功能的设计目的是“Gathers values along an axis specified by dim.”,这是官方文档的所有描述,看到这句话做目标检测的小伙伴应该能想到这样一个场景:
目标检测网络输出矩阵,前4列是box的坐标,第5列表示检测到目标的种类标签
print(pred)
tensor([[0.0080, 0.6403, 0.9865, 0.0158, 1.0000],
[0.2742, 0.7470, 0.3837, 0.6689, 3.0000],
[0.3260, 0.6683, 0.1888, 0.9525, 0.0000],
[0.7989, 0.9154, 0.1040, 0.5538, 3.0000],
[0.6746, 0.6193, 0.0161, 0.5166, 0.0000]])
现在我们要挑选出标签是3的所有检测目标框,
i = pred[:, 4].eq(3).nonzero().repeat(1, 4)
torch.gather(pred, 0, i)
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
[0.7989, 0.9154, 0.1040, 0.5538]])
gather()可以实现实现这种整行地抽取数据,但不是最优的实现方法,我们有更合适的实现方法,index_select()和下标索引:
i = pred[:, 4].eq(3).nonzero().squeeze()
pred.index_select(0, i)[:, :4]
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
[0.7989, 0.9154, 0.1040, 0.5538]])
# 下标索引方法
pred[i, :4]
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
[0.7989, 0.9154, 0.1040, 0.5538]])
现在我们要进行更加复杂的数据抽取,输出张量的要求如下:
这时候index_selsect()无法实现,但gather()可以
index = torch.tensor([[0, 1],
[3, 2]])
torch.gather(pred, 0, index)
tensor([[0.0080, 0.7470],
[0.7989, 0.6683]])
这个操作的规则如下:
输出张量的shape和索引张量(index)相同
除了dim指示的那个维度,其他所有的维度满足条件: index.size(d) <= input.size(d)
index和输入张量input的每个维度一一对应
除了dim指示的那个维度,其他维度的input和output元素位置对应,当index.size(d) < input.size(d)时候,从最前面截取
dim指示的那个维度上数据根据index里具体元素指示的位置去定位
看起来还是不好理解的,好在这个函数的应用场景不多,到目前为止我还没遇到适合这个函数的应用场景,如果哪位小伙伴遇到了请评论区留言感激不尽。