PyTorch Notes (29): A Neat Trick Combining torch.split and torch.roll

How to reverse a multi-hot encoding?

multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1]])
(multihot_batch == torch.tensor(1)).nonzero()
tensor([[0, 1],
        [0, 3],
        [1, 3],
        [2, 2],
        [2, 3]])

The desired result is:

[[1, 3],
[3],
[2, 3]]

For a task like this, first notice what the result looks like: a list containing tensors of different sizes. That immediately suggests torch.split, which can cut a tensor into chunks of varying sizes.
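As a minimal sketch of that behavior: torch.split accepts a list of chunk sizes, one per output tensor, and the sizes must sum to the length of the input.

```python
import torch

x = torch.arange(5)                  # tensor([0, 1, 2, 3, 4])
parts = torch.split(x, [2, 1, 2])    # sizes must sum to len(x)
print(parts)
# (tensor([0, 1]), tensor([2]), tensor([3, 4]))
```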

Second, we need to know how large each chunk should be. Here torch.roll helps: an expression like x - torch.roll(x, 1, 1) reveals the positions where the index changes, which tells us the size of each chunk.
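As a quick sanity check (a minimal sketch, not from the original post), here is what that differencing trick looks like on a 1-D tensor:

```python
import torch

x = torch.tensor([0, 0, 1, 2, 2])   # e.g. the row indices from nonzero()
# shift left by one along dim 0 and subtract: a 1 marks the position
# just before the row index jumps to the next row
print(torch.roll(x, -1, 0) - x)
# tensor([ 0,  1,  1,  0, -2])
```

The solution below applies this same trick to the row-index column of the nonzero output.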

import torch

multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1]])
#multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1], [0,0,0,1]])
vnon = (multihot_batch == torch.tensor(1)).nonzero(as_tuple=False)
v0 = vnon[:,0]  # row indices
v1 = vnon[:,1]  # column indices

# positions where the row index increases by 1; the +1 turns each
# 0-based position into the length of everything up to that boundary
# (note: this assumes every row contains at least one 1)
split_ind = ((torch.roll(v0, -1, 0) - v0) == 1).nonzero(as_tuple=False)[:,0] + 1
# the first index is exactly the size of the first chunk;
# the remaining chunks (apart from the last one) come from consecutive differences
split_size = torch.cat([split_ind[0].view(1), (torch.roll(split_ind, -1, 0) - split_ind)[:-1]])
# add the final chunk's size
final_size = torch.tensor([torch.numel(v1) - torch.sum(split_size)])
split_size = torch.cat([split_size, final_size])

print(torch.split(v1, split_size.tolist()))
# (tensor([1, 3]), tensor([3]), tensor([2, 3]))
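As an aside (a sketch, not part of the original trick), the chunk sizes can also be computed directly with torch.bincount over the row indices, which additionally handles rows that contain no 1 at all:

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])
idx = (multihot_batch == 1).nonzero(as_tuple=False)
# count how many ones each row has; minlength keeps all-zero rows as size-0 chunks
counts = torch.bincount(idx[:, 0], minlength=multihot_batch.size(0))
print(torch.split(idx[:, 1], counts.tolist()))
# (tensor([1, 3]), tensor([3]), tensor([2, 3]))
```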
