pytorch scatter_函数

这里表示,如果是一个3维张量,当dim设置为0(行)的时候,src参数的张量形状与self参数的张量形状,在除dim=0以外的维度,需要大小相同。及第二维大小都为j,第三维大小都为k。

这里简单举一个二维张量的例子,当dim设置为0的时候,src参数的张量 要求与self参数的张量在列上的大小相同(dim=1)。当dim设置为1的时候,src参数的张量 要求与self参数的张量在行上的大小相同(dim=0)。

以下为jupyter的案例及输出:

#%%

# 函数scatter_(dim, index, src) → Tensor  从src中取index位置的元素,index元素值表示要写入self的位置

# 例子1:原始self是一个3行5列的随机张量,index是一个2行5列的张量,dim=0(行)

# dim=0(行),要求self的列和 index的列大小一致,index元素值表示self的行下标,因此不能超过self的行

index = torch.tensor([[0,1,2,0,0], [2,0,0,1,2]]) # 2行5列

src = torch.rand(2,5)

print(src)

z1 = torch.zeros(3,5).scatter_(0, index, src)

z1

# 例子2:原始self是一个2行4列的0张量,index是一个2行1列的张量,

# dim=1(列),要求self的行和 index的行大小一致,index元素值表示self的列下标,因此不能超过self的列

index2 = torch.tensor([[2], [3]])

index2.shape

z2 = torch.zeros(2,4).scatter_(1, index2,1.23)

z2

# 例子3:

# dim=0(行),要求self的列和 index的列大小一致,index元素值表示self的行下标,因此不能超过self的行

z3 = torch.zeros(2,4).scatter_(0, torch.tensor([[0,1,0,0]]),1.23)

z3

你可能感兴趣的:(pytorch scatter_函数)