【pytorch工具记录】获取向量中值相同的元素的索引号

给定一个向量a,输出其中值相同的元素的索引号

比如给定[1,1,2,3,4,5,5,5,5],其中第0,1个元素的值都是1,要输出[0,1],第5,6,7,8个元素的值都是5,要输出[5,6,7,8],如果没有相同的,就输出元素自身的索引号。

def getIdx(a):
    co = a.unsqueeze(0)-a.unsqueeze(1)
    uniquer = co.unique(dim=0)
    out = []
    for r in uniquer:
        cover = torch.arange(a.size(0))
        mask = r==0
        idx = cover[mask]
        out.append(idx)
    return out

测试:

import torch
a = torch.Tensor([1,1,2,3,4,5,5,5,5])
idxs=getIdx(a) 
#output: [tensor([5, 6, 7, 8]), tensor([4]), tensor([3]), tensor([2]), tensor([0, 1])]

 

你可能感兴趣的:(pytorch工具存档,pytorch)