K近邻算法 (KNN)和最远点采样(FPS)实现--python+pytorch

KNN算法实现--python+pytorch

  • K近邻算法 (KNN)
  • 最远点采样(FPS)

K近邻算法 (KNN)

主要思路:计算每个点和某点的距离,取距离最短的K个点的下标即可。
下面是个完整示例,代码复制即可运行

import torch
import time

#生成点集
def coordinate_gen(n):
    """
    生成n个三位点
    return tensor
    dim:n*3
    """
    xyz = torch.rand(size=(n,3))
    return xyz

#计算运行时间
def time_cost(f):

    def run_time(*args,**kwargs):
        start = time.time()
        res = f(*args,**kwargs)
        run_times = time.time()-start
        print("程序执行时间:%.6f s"%(run_times))
        return res

    return run_time

@time_cost
def knn(xyz,xyzs,k=3):
    """
    xyz:key point
    xyzs:all points
    找某点的k近邻个点
    return 近邻点的下标列表
    """
    idx = [0]*k
    distance = torch.sum((xyzs[:,:3]-xyz[:,:3])**2,dim=-1)
    for i in range(k):
        idx[i] = torch.argmin(distance,dim=0)
        distance[int(torch.argmin(distance,dim=0))] = float('inf')
    idx = [int(i) for i in idx]
    return idx

if __name__ == "__main__":
    print('-' * 20, '测试开始', '-' * 20)
    N,k = map(int,input("输入生成点数 和 k的值:").split())
    xyzs = coordinate_gen(N)
    xyz = torch.rand(size=(1,3))
    print("生成的点如下:\n",xyzs,"\n随机生成key point:",xyz)
    print(knn(xyz,xyzs=xyzs,k=k))
    print('-'*20,'测试结束','-'*20)

最远点采样(FPS)

def farthest_point_sample(data,npoints):
    """
    Args:
        data:输入的tensor张量,排列顺序 N,D
        Npoints: 需要的采样点

    Returns:data->采样点集组成的tensor,每行是一个采样点
    """
    N,D = data.shape #N是点数,D是维度
    xyz = data[:,:3] #只需要坐标
    centroids = torch.zeros(size=(npoints,)) #最终的采样点index
    dictance = torch.ones(size=(N,))*1e10 #距离列表,一开始设置的足够大,保证第一轮肯定能更新dictance
    farthest = torch.randint(low=0,high=N,size=(1,)) #随机选一个采样点的index
    for i in range(npoints):
        centroids[i] = farthest
        centroid = xyz[farthest,:]
        dict = ((xyz-centroid)**2).sum(dim=-1)
        mask = dict < dictance
        dictance[mask] = dict[mask]
        farthest = torch.argmax(dictance,dim=-1)
    print(centroids.type(torch.long))
    data= data[centroids.type(torch.long)]
    return data

你可能感兴趣的:(深度学习,pytorch,python,算法)