PyTorch 三维 one-hot tensor 的制作（使用 `scatter_`）

import torch

batch_size = 2
sequence_len = 3
hidden_dim = 5


def make_one_hot(index, depth, value=1.0):
    """Build a one-hot tensor by scattering ``value`` along the last dimension.

    Args:
        index: Integer (long) tensor whose last dimension has size 1; each
            entry is the position along the last axis to set, e.g. shape
            ``(batch_size, sequence_len, 1)`` for a 3-D result.
        depth: Size of the last dimension of the result (number of classes).
        value: Scalar written at every indexed position; all other entries
            are 0.

    Returns:
        A float tensor of shape ``index.shape[:-1] + (depth,)``.
    """
    # Derive the leading dimensions from ``index`` so the output shape can
    # never disagree with the index tensor it is filled from.
    out = torch.zeros(*index.shape[:-1], depth)
    return out.scatter_(dim=-1, index=index, value=value)


# Each index literal has shape (batch_size, sequence_len, 1): one hot
# position per row. The second element of each pair is the fill value.
for positions, fill in (
    ([[[2], [2], [1]], [[1], [2], [3]]], 1),
    ([[[2], [2], [0]], [[1], [2], [3]]], 2),
    ([[[2], [3], [1]], [[1], [2], [3]]], 2),
):
    index = torch.tensor(positions, dtype=torch.long)
    x = make_one_hot(index, hidden_dim, value=fill)
    print(x)

print 输出结果：

tensor([[[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.]]])
         
         
tensor([[[0., 0., 2., 0., 0.],
         [0., 0., 2., 0., 0.],
         [2., 0., 0., 0., 0.]],

        [[0., 2., 0., 0., 0.],
         [0., 0., 2., 0., 0.],
         [0., 0., 0., 2., 0.]]])


tensor([[[0., 0., 2., 0., 0.],
         [0., 0., 0., 2., 0.],
         [0., 2., 0., 0., 0.]],

        [[0., 2., 0., 0., 0.],
         [0., 0., 2., 0., 0.],
         [0., 0., 0., 2., 0.]]])

你可能感兴趣的:(PyTorch)