Python实现Hash Heap数据结构

原文链接: http://hankerzheng.com/blog/Python-Hash-Heap

前言

LeetCode 480 Sliding Window Median 这题真是丧心病狂! 找中位数自然要用 Heap 或者 Balanced BST, 而 Sliding Window 又要求能删除窗口中任意一个元素 (Delete 操作). 可惜 Python 标准库中并没有同时支持这两种操作的数据结构: `heapq` 只能高效地弹出堆顶, 无法在 O(log K) 时间内删除任意元素. 因此, 想在 O(N log K) 的时间复杂度内实现 Sliding Window Median, 只能自己写一个数据结构了. 本文用 Python 实现了 Hash Heap 数据结构, 并附带简单的测试函数.

My Implementation of Other Data Structure and Algorithm - HankerZheng’s GitHub

Implementation Thoughts

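基本实现和普通的 Heap 没什么区别, 只是需要一个额外的 HashMap 记录每个 element 在 Heap 中的 index. 每次做交换操作时, 不仅要交换两个元素在 Heap 中的位置, 还要同步更新 HashMap 中对应的 index. 这样删除任意元素时, 先通过 HashMap 在 O(1) 时间内找到它的位置, 再把它与最后一个元素交换并删除, 最后对换过来的元素做 sift up 或 sift down, 整个删除操作只需 O(log n).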

需要额外考虑的是如何处理 Duplicates. 我采用的方法是新建一个 HeapNode 类, 存储当前元素的值以及相同值元素的个数. 这样 Heap 中就不会出现重复的值 (Duplicates), 以元素值为 key 记录 index 的 HashMap 也就不会出现 key 冲突的问题了.

Source Code

# This is a Python implementation of a Hash Heap, built on the list-based
# implementation of a binary min-heap. Unlike a plain binary heap, a Hash Heap
# supports removing an arbitrary element (`heapRemove`) in O(log n) time and
# checking whether a given element is in the heap in O(1) average time.
#
# Basic automatic tests are given in `pushpopTest()` and `removeTest()`.
# Note: it may take about 10 seconds to run both test functions.

import random

class HeapNode(object):
    """
    The node in the HashHeap used to deal with duplicates.

    Each node stores the value of an element together with the number of
    copies of that value, so a heap built from HeapNodes never holds two
    nodes with the same value.

    Nodes are ordered by `val` only; `cnt` never takes part in comparisons.
    Both the Python 2 `__cmp__` protocol and the rich comparison methods are
    provided, so the ordering works on Python 2 and Python 3 alike.
    """
    def __init__(self, val, cnt):
        self.val = val  # the element value
        self.cnt = cnt  # how many copies of `val` this node represents

    def __cmp__(self, other):
        # Return the sign of the comparison instead of `self.val - other.val`.
        # A subtraction fails for non-numeric values (e.g. strings). Python 2
        # also converts the result to an int, so a float difference such as
        # 0.5 would make 1.5 and 1.0 compare as equal.
        return (self.val > other.val) - (self.val < other.val)

    # Python 3 ignores `__cmp__`, so the operators used on nodes
    # (<, <=, >, >=) are defined explicitly.
    def __lt__(self, other):
        return self.val < other.val

    def __le__(self, other):
        return self.val <= other.val

    def __gt__(self, other):
        return self.val > other.val

    def __ge__(self, other):
        return self.val >= other.val

    def __str__(self):
        return "[%s, %d]" % (self.val, self.cnt)
    __repr__ = __str__

class HashHeap(object):
    """
    A binary min-heap kept in a list, plus a hash map from each distinct
    element value to the index of its node in that list.

    The hash map gives O(1) average membership tests and lets `heapRemove`
    delete an arbitrary element in O(log n). Duplicates are folded into a
    single `HeapNode` whose `cnt` counts the copies, so each value appears at
    most once in the list and is a unique key in the hash map.

    Stored values must be hashable and mutually comparable with `<`.
    """
    def __init__(self, arr=()):
        """
        Build a heap holding every element of `arr` (duplicates allowed).

        `arr` - any iterable of elements; defaults to an empty heap.

        `_cap` - the number of elements in the HashHeap, counting duplicates
        `_maxIdx` - the max index of the binary heap (-1 when empty)
        `_data` - the list implementation of the binary heap
        `_hashMap` - mapping the element to its index in the binary heap
        """
        elemCnt = self._preProcess(arr)
        # Sum the tallies instead of calling len(arr), so one-shot iterables
        # such as generators are accepted too.
        self._cap = sum(elemCnt.values())
        self._maxIdx = len(elemCnt) - 1
        self._data = [HeapNode(key, value) for key, value in elemCnt.items()]
        self._hashMap = {node.val: idx for idx, node in enumerate(self._data)}
        self._heapify()

    def _preProcess(self, arr):
        """
        Convert the input iterable into a dict object.
        The key to the dict is the value of the element.
        The value of the dict is the number of occurrences of that element.
        """
        elemCnt = {}
        for elem in arr:
            elemCnt[elem] = elemCnt.get(elem, 0) + 1
        return elemCnt

    def _less(self, idx1, idx2):
        """
        Return True if the node at `idx1` holds a strictly smaller value than
        the node at `idx2`.
        """
        # Compare the raw values rather than the HeapNode objects, so the heap
        # does not depend on HeapNode's comparison protocol (Python 3 ignores
        # `__cmp__`).
        return self._data[idx1].val < self._data[idx2].val

    def _swap(self, idx1, idx2):
        """
        Swap the 2 elements in the heap.
        Also, change the index stored in `self._hashMap`
        """
        elem1, elem2 = self._data[idx1], self._data[idx2]
        self._hashMap[elem1.val] = idx2
        self._hashMap[elem2.val] = idx1
        self._data[idx1], self._data[idx2] = elem2, elem1

    def _heapify(self):
        """
        Establish the heap property over the whole list after construction.

        Child indices are visited from the last one down to 1. When a parent
        is larger than the visited child, the two are swapped and the old
        parent is sifted down inside that child's subtree. Both children of a
        node are visited before the node itself is visited, so every subtree
        is already a heap when its root is compared with its parent.
        """
        idx = self._maxIdx
        while idx > 0:
            # Floor division: `/` would produce a float index on Python 3.
            parentIdx = (idx - 1) // 2
            if self._less(idx, parentIdx):
                self._swap(parentIdx, idx)
                self._siftDown(idx)
            idx -= 1

    def _siftDown(self, idx):
        """
        Move the node at `idx` down until it is no larger than its children.
        """
        current = idx
        while True:
            left, right = current * 2 + 1, current * 2 + 2
            if left > self._maxIdx:
                return  # leaf: nothing below to compare against
            child = left
            if right <= self._maxIdx and self._less(right, left):
                child = right
            if not self._less(child, current):
                return  # already no larger than its smaller child
            self._swap(current, child)
            current = child

    def _siftUp(self, idx):
        """
        Move the node at `idx` up while it is smaller than its parent.
        """
        current = idx
        while current > 0:
            parent = (current - 1) // 2
            if not self._less(current, parent):
                break
            self._swap(parent, current)
            current = parent

    def _removeLastNode(self):
        """
        Drop the node at the end of the list and forget its index.
        Decrements the element count by one, so callers only use it on a
        node whose `cnt` is 1.
        """
        rmNode = self._data.pop(-1)
        self._cap -= 1
        self._maxIdx -= 1
        self._hashMap.pop(rmNode.val)

    def _removeByIdx(self, idx):
        """
        Remove one occurrence of the value stored at `idx` and return it.
        """
        thisNode = self._data[idx]
        retVal = thisNode.val
        if thisNode.cnt > 1:
            # Other copies remain: only the count changes, the layout does not.
            thisNode.cnt -= 1
            self._cap -= 1
        elif idx == self._maxIdx:
            # the node itself is the last node
            self._removeLastNode()
        else:
            self._swap(idx, self._maxIdx)
            self._removeLastNode()
            # The node moved into `idx` came from elsewhere in the tree. It may
            # be smaller than its new parent or larger than a child, so restore
            # order in whichever direction is needed.
            if idx > 0 and self._less(idx, (idx - 1) // 2):
                self._siftUp(idx)
            else:
                self._siftDown(idx)
        return retVal

    @property
    def length(self):
        """
        Return the number of elements in the Hash Heap (duplicates included)
        """
        return self._cap

    def heapPeep(self):
        """
        Return the MIN element in the Hash Heap without removing it.
        Returns float("inf") when the heap is empty.
        """
        if not self._data:
            return float("inf")
        return self._data[0].val

    # Correctly spelled alias; `heapPeep` is kept for existing callers.
    heapPeek = heapPeep

    def heapPop(self):
        """
        Remove the MIN element from the Hash Heap and return its value.
        Raises IndexError when the heap is empty.
        """
        if not self._data:
            raise IndexError("heapPop from an empty HashHeap")
        return self._removeByIdx(0)

    def heapPush(self, elem):
        """
        Push a new element into the Hash Heap
        """
        self._cap += 1
        if elem not in self._hashMap:
            self._maxIdx += 1
            self._data.append(HeapNode(elem, 1))
            self._hashMap[elem] = self._maxIdx
            self._siftUp(self._maxIdx)
        else:
            # Duplicate value: bump the existing node's count; order unchanged.
            idx = self._hashMap[elem]
            self._data[idx].cnt += 1

    def heapRemove(self, elem):
        """
        Remove one occurrence of an existing element from the Hash Heap.
        If the element to be removed is not in the Hash Heap, raise ValueError.
        """
        if elem not in self._hashMap:
            raise ValueError("Element to be removed is not in HashHeap!!!")
        idx = self._hashMap[elem]
        self._removeByIdx(idx)

    def __contains__(self, value):
        """
        Return True if at least one copy of `value` is stored.
        """
        return value in self._hashMap

    def __str__(self):
        # Distinct values in internal list order; duplicate counts are not shown.
        return "%s" % [elem.val for elem in self._data]
    __repr__ = __str__


def pushpopTest(rounds=100, size=1000):
    """
    Randomly generate a list, and push each element into the heap.
    Test heapPush by comparing the first element in the heap with the
    smallest element in the list after every push.
    Test heapPop by comparing the popped elements from the heap with the
    sorted list one by one.

    `rounds` - how many independent heaps are built and checked
    `size` - how many random elements are pushed into each heap
    Raises AssertionError on the first mismatch.
    """
    # `range` replaces the Python-2-only `xrange`; it works on both versions.
    for _ in range(rounds):
        thisHeap = HashHeap([0])
        testList = [0]
        # Track the minimum incrementally instead of calling min(testList)
        # after every push, which made each round quadratic in `size`.
        curMin = 0
        for _step in range(size):
            thisRandom = random.randrange(-100, 100000)
            thisHeap.heapPush(thisRandom)
            testList.append(thisRandom)
            curMin = min(curMin, thisRandom)
            assert curMin == thisHeap.heapPeep()
            assert len(testList) == thisHeap.length
            # Exactly one hash-map entry per node in the heap list.
            assert len(thisHeap._hashMap) == thisHeap._maxIdx + 1
        testList.sort()
        assert len(testList) == thisHeap.length
        for idx, num in enumerate(testList):
            assert num == thisHeap.heapPop()
            assert len(testList) - 1 - idx == thisHeap.length
            assert len(thisHeap._hashMap) == thisHeap._maxIdx + 1

def removeTest(rounds=100, size=1000, removeProb=0.2):
    """
    Randomly generate a list, and push each element into the heap.
    After every push, with probability `removeProb`, remove a randomly chosen
    stored element (not necessarily the one just pushed) from both the heap
    and the list. Then check that the smallest element in the heap matches
    the smallest element in the list. Finally pop everything and compare it
    with the sorted list.

    `rounds` - how many independent heaps are built and checked
    `size` - how many random elements are pushed into each heap
    `removeProb` - chance of a removal after each push
    Raises AssertionError on the first mismatch.
    """
    # `range` replaces the Python-2-only `xrange`; it works on both versions.
    for _ in range(rounds):
        thisHeap = HashHeap([0])
        testList = [0]
        # Running minimum of testList; recomputed only when the minimum itself
        # may have been removed.
        curMin = 0
        for _step in range(size):
            thisRandom = random.randrange(-100, 100000)
            thisHeap.heapPush(thisRandom)
            testList.append(thisRandom)
            curMin = min(curMin, thisRandom)
            if random.random() < removeProb:
                # Remove any stored element, so removals can hit arbitrary heap
                # positions rather than only the node that was just pushed.
                # testList has at least 2 items here, so it stays non-empty.
                victim = testList.pop(random.randrange(len(testList)))
                assert victim in thisHeap
                thisHeap.heapRemove(victim)
                if victim == curMin:
                    curMin = min(testList)
            assert curMin == thisHeap.heapPeep()
            assert len(testList) == thisHeap.length
            # Exactly one hash-map entry per node in the heap list.
            assert len(thisHeap._hashMap) == thisHeap._maxIdx + 1
        testList.sort()
        assert len(testList) == thisHeap.length
        for idx, num in enumerate(testList):
            assert num == thisHeap.heapPop()
            assert len(testList) - 1 - idx == thisHeap.length
            assert len(thisHeap._hashMap) == thisHeap._maxIdx + 1


if __name__ == '__main__':
    # Run each self-check in order; a failure surfaces as an AssertionError.
    for runTest in (pushpopTest, removeTest):
        runTest()

你可能感兴趣的:(刷题,学习笔记)