python实现KD树模型

import numpy as np


class KDTree:
    """A binary tree node used to build a KD tree over k-dimensional points.

    Each node stores one value in ``key`` plus ``lchild``/``rchild`` subtrees.
    ``buildKD`` represents an empty subtree as a node whose key is the
    single-space placeholder ``' '`` and whose children are ``None``.
    """

    def __init__(self, obj):
        self.key = obj        # value (point or placeholder) stored at this node
        self.lchild = None    # left subtree (smaller split coordinate, via buildKD)
        self.rchild = None    # right subtree (larger/equal coordinate, via buildKD)

    def addlChild(self, obj):
        """Insert ``obj`` as the new left child of this node.

        ``obj`` is always wrapped in a new KDTree node. Any existing left child
        is pushed down to become the left child of that new node.
        """
        node = KDTree(obj)
        node.lchild = self.lchild  # None when there was no left child yet
        self.lchild = node

    def addrChild(self, obj):
        """Insert ``obj`` as the new right child of this node.

        ``obj`` is always wrapped in a new KDTree node. Any existing right
        child is pushed down to become the right child of that new node.
        """
        node = KDTree(obj)
        node.rchild = self.rchild  # None when there was no right child yet
        self.rchild = node

    def getRootVal(self):
        """Return the value stored at this node."""
        return self.key

    def splitLR(self, root, left, right):
        """Set this node's key to ``root``.

        NOTE(review): ``left`` and ``right`` are currently accepted but
        ignored; presumably they were meant to become the children - confirm
        before relying on them.
        """
        self.key = root

    def buildKD(self, data, depth):
        """Recursively build a KD tree rooted at this node from ``data``.

        :param data: list of points. Each point is an indexable sequence, and
            all points are assumed to share the same length k (checked only
            via the first point).
        :param depth: depth of this node. The split axis is ``depth % k``.

        The first point whose coordinate on the split axis equals the median
        becomes this node's key. Points strictly below the median go to the
        left subtree. All remaining points, including other ties with the
        median, go to the right subtree, so no point is lost. Both children
        are created as placeholder ``KDTree(' ')`` nodes and filled
        recursively. An empty ``data`` leaves this node untouched.
        """
        if len(data) == 0:
            return
        numAxis = len(data[0])
        splitAxis = depth % numAxis
        # Compute the median once instead of twice per point.
        median = self.calMedian([point[splitAxis] for point in data])
        root = None
        lchilds = []
        rchilds = []
        for point in data:
            value = point[splitAxis]
            if root is None and value == median:
                root = point  # the median is always one of the values
            elif value < median:
                lchilds.append(point)
            else:
                rchilds.append(point)
        self.key = root
        self.lchild = KDTree(' ')
        self.rchild = KDTree(' ')
        self.lchild.buildKD(lchilds, depth + 1)
        self.rchild.buildKD(rchilds, depth + 1)

    def calMedian(self, data):
        """Return the upper median of a non-empty sequence of numbers.

        The values are sorted, and the element at index ``len(data) // 2`` is
        returned. For an even count this is the larger of the two middle
        values. Floor division keeps the index an integer on both Python 2
        and Python 3. An empty sequence raises IndexError.
        """
        ordered = np.sort(data)
        return ordered[len(data) // 2]


def printTree(KDTree):
    """Print the key of every real node of a KD tree in pre-order.

    Nodes are visited root first, then the left subtree, then the right
    subtree. Each node's ``key`` is printed on its own line.

    A ``None`` child, or a placeholder node whose key is the single-space
    string ``' '``, marks an empty subtree and prints nothing. This means
    nodes with only one real child are handled instead of dereferencing a
    missing child.

    The parameter keeps its historical name ``KDTree`` for caller
    compatibility, although it shadows the class of the same name here.
    """
    node = KDTree
    if node is None:
        return
    key = node.key
    # ' ' is the placeholder key used for empty subtrees; the isinstance
    # check avoids comparing non-string keys against a string.
    if isinstance(key, str) and key == ' ':
        return
    print(key)
    printTree(node.lchild)
    printTree(node.rchild)

if __name__ == "__main__":
    # Demo: build a KD tree from six sample 2-D points, then print it.
    data = [
        [2, 3],
        [5, 4],
        [9, 6],
        [4, 7],
        [8, 1],
        [7, 2],
    ]
    tree = KDTree(' ')  # placeholder key; buildKD overwrites it for non-empty data
    tree.buildKD(data, depth=0)  # depth 0 -> first split on axis 0
    printTree(tree)

你可能感兴趣的:(python实现KD树模型)