这里,我们介绍其中一种方法,即torch.scatter_()函数
import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)
b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a)
#参数解释:‘1’:需要赋值的维度;‘torch.LongTensor(b)’:需要赋值的索引;‘a’:要赋的值
print("new_label: ",label)
输出:
label:
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
new_label:
tensor([[1., 1., 1., 0., 0., 0.],
[1., 1., 0., 1., 0., 0.],
[0., 1., 1., 1., 0., 0.]])
可以实现相同功能的函数还有:index_fill_(dim, index, val) ;index_put_(indices, value)
2.index_fill_(dim, index, val)
dim:要填充的维度
index:要填充的索引
val:要填充的值
与上面的用法类似
3.index_put_(indices, value)
indices:要填充的索引,与上面不同的是,这里直接使用的是要填充值的行和列
value:要填充的值
用法:
a = torch.zeros([5,5])
index = (torch.LongTensor([0,1]),torch.LongTensor([1,2])#生成索引
value = torch.Tensor([1,1]) #生成要填充的值
a.index_put_(index), value)
这就是上述三种方法,欢迎大家留言交流!