import torch
# Shape shared by every demo tensor: (batch, sequence, hidden).
batch_size = 2
sequence_len = 3
hidden_dim = 5


def scatter_value(index, value, size=None):
    """Return a zero tensor with ``value`` written at the columns given by ``index``.

    This is ``torch.zeros(size).scatter_(dim=-1, index=index, value=value)``.
    For the demo shape, row ``[b][s]`` of the result gets ``value`` at column
    ``index[b][s][0]``. Every other entry stays 0.

    Args:
        index: Nested list or integer tensor of column indices, shaped like
            ``(batch, sequence, 1)`` for the demos below. Each entry must lie
            in ``[0, size[-1])``, otherwise ``scatter_`` raises.
        value: Scalar written at every selected position.
        size: Shape of the returned tensor. Defaults to
            ``(batch_size, sequence_len, hidden_dim)``.

    Returns:
        A float tensor (the ``torch.zeros`` default dtype) of shape ``size``.
    """
    if size is None:
        size = (batch_size, sequence_len, hidden_dim)
    # scatter_ requires an int64 index; as_tensor replaces the legacy
    # torch.LongTensor constructor and also accepts an existing tensor.
    index = torch.as_tensor(index, dtype=torch.long)
    return torch.zeros(size).scatter_(dim=-1, index=index, value=value)


# Demo 1: write 1 at the chosen column of every (batch, sequence) row.
x = scatter_value([[[2], [2], [1]], [[1], [2], [3]]], value=1)
print(x)
# Demo 2: same pattern, but batch 0 / row 2 now targets column 0; value 2.
x = scatter_value([[[2], [2], [0]], [[1], [2], [3]]], value=2)
print(x)
# Demo 3: batch 0 rows target columns 2, 3, 1; value 2.
x = scatter_value([[[2], [3], [1]], [[1], [2], [3]]], value=2)
print(x)
# Printed output:
# tensor([[[0., 0., 1., 0., 0.],
#          [0., 0., 1., 0., 0.],
#          [0., 1., 0., 0., 0.]],
#
#         [[0., 1., 0., 0., 0.],
#          [0., 0., 1., 0., 0.],
#          [0., 0., 0., 1., 0.]]])
# tensor([[[0., 0., 2., 0., 0.],
#          [0., 0., 2., 0., 0.],
#          [2., 0., 0., 0., 0.]],
#
#         [[0., 2., 0., 0., 0.],
#          [0., 0., 2., 0., 0.],
#          [0., 0., 0., 2., 0.]]])
# tensor([[[0., 0., 2., 0., 0.],
#          [0., 0., 0., 2., 0.],
#          [0., 2., 0., 0., 0.]],
#
#         [[0., 2., 0., 0., 0.],
#          [0., 0., 2., 0., 0.],
#          [0., 0., 0., 2., 0.]]])