pytorch之tensor按索引赋值,三种方法!

这里,我们介绍其中一种方法,即torch.scatter_()函数

import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)

b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a) 
#参数解释:‘1’:需要赋值的维度;‘torch.LongTensor(b)’:需要赋值的索引;‘a’:要赋的值
print("new_label: ",label)

输出:
 

label: 
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
new_label:  
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 0.]])

可以实现相同功能的函数还有:index_fill_(dim, index, val) ;index_put_(indices, value)

2.index_fill_(dim, index, val)
dim:要填充的维度
index:要填充的索引
val:要填充的值
与上面的用法类似

3.index_put_(indices, value)
indices:要填充的索引,与上面不同的是,这里直接使用的是要填充值的行和列
value:要填充的值

用法:

a = torch.zeros([5,5])
index = (torch.LongTensor([0,1]),torch.LongTensor([1,2])#生成索引
value = torch.Tensor([1,1]) #生成要填充的值
a.index_put_(index), value)

这就是上述三种方法,欢迎大家留言交流!

你可能感兴趣的:(pytorch,功能代码积累)