在官方文档中,主要是针对三维进行解释,但是它对数据索引的混乱使用,导致晦涩难懂。所以我主要针对2维的进行解释,更加通俗。
1、gather()
首先规定好。对于原始数据输入x(i,j)
对于输出y (m,n)
对于index(a,b)
dim=0,代表行,dim=1代表列
>>> x = torch.Tensor([[1,2],[3,4]])
1 2
3 4
>>> torch.gather(x, 1, torch.LongTensor([[0,0],[1,0]])
首先我们分析dim=1,所以就应该针对输入x的列进行操作
step1:我们先要求取y[0,0]的值,本来y[0,0]应该对应x[0,0],但是由于dim=1,所以应该要对列进行操作修改。所以用index的第一个值即index[0,0]=0,替换掉x的j值,所以x应该表示为
x【0,index[0,0]=0】=x[0,0]=1,所以y[0,0]=1
step2:这时需要求取y[0,1].同上y[0,1]应该对应x[0,1],列j值修改,用index[0,1]=0代替,此时x的索引被修改为x【0,index[0,1]】,即x[0,0],所以此时y[0,1]=x[0,0]=1
step3:y[1,0]应该对应x[1,0],但是索引j值被index[1,0]=1替换,所以x索引变为x[1,1]
所以此时的y[1,0]对应为x[1,1],等于4
step4:y[1,1]对应x[1,1],列j值被index[1,1]=0替换,x变为x[1,0]。所以y[1,1]=x[1,0]=3
2、sactter()和sactter_()
其实这俩差不多,所以先讲好理解的sactter()。
同理先规定好原始数据x(i,j)
对于输出y (m,n)
对于index(a,b)
dim=0,代表行,dim=1代表列
x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
y = torch.zeros(3, 5).scatter(dim=0, index=torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
>>> y
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
首先我们指导dim=0,所以是针对x的行进行操作
step1:这里我们先从x[0,0]=0.4319开始,即i=0,j=0.那么原本对应y:m=i=0,n=j=0,但是应该对x中的列i值进行修改,i应该等于index的第一个值,即i=index[0,0]=0.所以y也应该修改:m=i=0,n=j=0,,所以y[m=0,n=0]=x[0,0]=0.4319
step2:这里再从x[0,1]=0.6500开始,即i=0,j=1,那么原本对应y:m=i=0,n=j=1,但是i进行修改,i修改为index的第二个值,即i=index[0,1]=1,所以y对应的索引:m=i=1,n=j=1.所以此时
y[m=1,n=1]=x[0,1]=0.6500
step3:这里再看x[0,2]=0.4080开始,那么y的m=i=0,n=j=2.但i被修改为index[0,2]=2,所以此时的y的m=2,n=2。所以y[2,2]=x[0,2]=0.4080
step4:原理同上
sactter_()就相当于原地替换,实现的效果如下
假设原始数据为x
y = x.sactter()
x=y
最终是把原始数据更新一下