pytorch gather() 、sactter()和sactter_()的详解

在官方文档中,主要是针对三维进行解释,但是它对数据索引的混乱使用,导致晦涩难懂。所以我主要针对2维的进行解释,更加通俗。

1、gather()

首先规定好。对于原始数据输入x(i,j)

对于输出y (m,n)

对于index(a,b)

dim=0,代表行,dim=1代表列

>>> x = torch.Tensor([[1,2],[3,4]])
1  2
3  4
>>> torch.gather(x, 1, torch.LongTensor([[0,0],[1,0]])

首先我们分析dim=1,所以就应该针对输入x的列进行操作

step1:我们先要求取y[0,0]的值,本来y[0,0]应该对应x[0,0],但是由于dim=1,所以应该要对列进行操作修改。所以用index的第一个值即index[0,0]=0,替换掉x的j值,所以x应该表示为

x【0,index[0,0]=0】=x[0,0]=1,所以y[0,0]=1

step2:这时需要求取y[0,1].同上y[0,1]应该对应x[0,1],列j值修改,用index[0,1]=0代替,此时x的索引被修改为x【0,index[0,1]】,即x[0,0],所以此时y[0,1]=x[0,0]=1

step3:y[1,0]应该对应x[1,0],但是索引j值被index[1,0]=1替换,所以x索引变为x[1,1]

所以此时的y[1,0]对应为x[1,1],等于4

step4:y[1,1]对应x[1,1],列j值被index[1,1]=0替换,x变为x[1,0]。所以y[1,1]=x[1,0]=3

2、sactter()和sactter_()

其实这俩差不多,所以先讲好理解的sactter()。

同理先规定好原始数据x(i,j)

对于输出y (m,n)

对于index(a,b)

dim=0,代表行,dim=1代表列

x = torch.rand(2, 5)
>>> x
 
 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]
 
y = torch.zeros(3, 5).scatter(dim=0, index=torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
>>> y

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029

首先我们指导dim=0,所以是针对x的行进行操作

step1:这里我们先从x[0,0]=0.4319开始,即i=0,j=0.那么原本对应y:m=i=0,n=j=0,但是应该对x中的列i值进行修改,i应该等于index的第一个值,即i=index[0,0]=0.所以y也应该修改:m=i=0,n=j=0,,所以y[m=0,n=0]=x[0,0]=0.4319

step2:这里再从x[0,1]=0.6500开始,即i=0,j=1,那么原本对应y:m=i=0,n=j=1,但是i进行修改,i修改为index的第二个值,即i=index[0,1]=1,所以y对应的索引:m=i=1,n=j=1.所以此时

y[m=1,n=1]=x[0,1]=0.6500

step3:这里再看x[0,2]=0.4080开始,那么y的m=i=0,n=j=2.但i被修改为index[0,2]=2,所以此时的y的m=2,n=2。所以y[2,2]=x[0,2]=0.4080

step4:原理同上

sactter_()就相当于原地替换,实现的效果如下

假设原始数据为x

y = x.sactter()
x=y

最终是把原始数据更新一下

你可能感兴趣的:(pytorch,python)