torch.gather/scatter_使用
说明:dim=0(按照竖直方向操作,即↓),dim=1(按照水平方向操作,即→)
gather是取的意思,scatter_是放的意思。
用法:output=torch.gather(input, dim, index)
a = torch.Tensor([[1,2],[3,4]])
b = torch.gather(a, 1, torch.LongTensor([[0,0],[1,0]]))
print(b)
这里的“取”的意思是根据gather 函数中的参数从a中取出数据放入到变量b中, dim为 1,表示水平方向从 a 中取数据,a 中的数据对应的索引为[0,1][0,1], 然后 根据取数据的索引LongTensor([[0,0],[1,0]])从a中对应的行取出数据,LongTensor([[0,0],[1,0]])表示,[0,0]表示从第一行取两次索引为0的数据1,[1,0]表示从 第二行分别取出索引为 1 和 0 的数据[4,3]。LongTensor 表示将取出的数据转化为Long型。所以 最后b为 tensor([[1.,1.],[4.,3.]])
用法:output = torch.Tensor.scatter_(dim, index, src)
a = torch.rand(2, 5)
print(a)
b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print(b)
下面是一个直观理解的图:
结合上图可以理解为,scatter_函数是从src中取出数据的,dim为 0,表示对src的索引编号是按照竖直方向的,index为torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])表示,src 中的对应的数据放到b 中的哪个位置。 用上述例子来解释就是,作为 src的数据a, tensor([[0.8294, 0.5650, 0.6233, 0.6785, 0.3359],[0.7799, 0.4003, 0.4345, 0.4769, 0.8181]]) 中的每个数据与 作为index索引数据的torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) 相对应,按照位置对应,即 a 中的第一个数据 0.8294 与 index 中的第一个数据 0 对应,且该 0 在 0 列,数据0 表明的是将 0.8294 放到 b 中的 0 位置处,这个 0 位置是要结合dim来确定的,dim 为 0, 表示竖直方向,所以就是说放在b中的 0 列 0 行,同理,index的第二个数据是 1, 它在第1列,即表明将a中的第二个数据0.5650 放到b中的第 1 列第一行。
将所有的src数据放置到新的tensor后,由于新的 tensor 维度大于src的维度,会产生空位置,空位置补 0 。
https://zhuanlan.zhihu.com/p/59346637
(1):维度dim,保持原有维度keepdim
下面通过图像的形式直观的展示了这两个参数的作用。
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True)) # dim为0,按照相同的列求和,并在结果中保留列特征
print(X.sum(dim=1, keepdim=True)) # dim为1,按照相同的行求和,并在结果中保留行特征
print(X.sum(dim=0, keepdim=False)) # dim为0,按照相同的列求和,不在结果中保留列特征
print(X.sum(dim=1, keepdim=False)) # dim为1,按照相同的行求和,不在结果中保留行特征
(2):view函数的用法
view函数的用法如下所示,就是用于改变tensor的维度。其中-1表示当前维度会根据其余指定维度自适应得到。
y = torch.LongTensor([0, 2])
print(y,y.shape)
print(y.view(-1, 1),y.view(-1, 1).shape)
#-------------------------
tensor([0, 2]) torch.Size([2])
tensor([[0],
[2]]) torch.Size([2, 1])
————————————————
(3):gather函数的用法
ganther函数的用法如下所示,用于批量取出目标tensor中对应维度的数据。
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y1 = torch.LongTensor([[0, 1, 1]])
y2 = torch.LongTensor([[1,2]])
print(y_hat.gather(0, y1.view(1, -1)))
print(y_hat.gather(1, y2.view(-1, 1)))
#---------------------
tensor([[0.1000, 0.2000, 0.5000]])
tensor([[0.3000],
[0.5000]])
(4):argmax函数的用法
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat.argmax(dim=0))
print(y_hat.argmax(dim=1))
#-----------------------
tensor([1, 0, 0])
tensor([2, 2])
https://blog.csdn.net/qiu931110/article/details/104292178