从小顶堆到堆排序——超详细图解——Python3实现

文章目录

  • 前言
  • 最小堆实现
    • 已知所有节点,原地构建最小堆
    • 最小堆删除顶点
    • 最小堆添加节点
    • 实时插入删除
  • 堆排序
  • 总结

前言

在简单选择排序中,每次选择会从待排序元素中找到最小值,但每次选择都需要遍历完剩余所有元素,而且在遍历时没有记录起来有用信息,这显得很浪费。

而堆排序则利用了最小堆(或最大堆)的性质,每次选择最小值都会利用堆的数据结构来保存有用信息,即总是使得整个堆是一个最小堆,以便下一次选择时直接选择索引为0的节点就可以了。

注意,本文使用到一个打印完全二叉树的算法方便我们观察整个堆的样子。具体做法是,新建一个printHeap.py,把这个算法除了测试代码都放进去。(不要误会,我绝对不是因为嫌画图麻烦才来写这个算法的~)(PS:有了这个算法,超详细图解不是梦,因为我可以在每一步打印出堆的样子)

最小堆实现

堆是一种完全二叉树,即除了最后一层外每层的节点都被放满了。而最小堆则有这种性质:每个节点都小于等于它的左右孩子的值(包括它的左右子树的所有节点的值)。

一般我们用数组作为堆的实际存储,因为父节点和子节点之间有这样的关系我们可以利用(节点编号从0开始):

  • arr[i] 的左孩子是arr[2i+1],右孩子是arr[2i+2]。
    • 对于最小堆来说,arr[i] <= arr[2i+1]arr[i] <= arr[2i+2]都成立。
  • 最后一个非叶子节点(或者说最后一个叶子节点的父节点)的索引为arr.length/2-1

已知所有节点,原地构建最小堆

这种情况我们直接把所有节点放到一个数组里。在构建时,我们基于这样一种规则:对于每一颗子树,如果它的根节点的左子树和右子树都已经是最小堆了,那么只需要将根节点冒泡下移到合适的位置(或者根本不需要操作,因为它已经是最小堆),就可以使得该子树成为一个最小堆。

从这个规则,我们可以得知,需要从最低层开始处理,直到最高层。因为处理完下面一层后,再去处理上面一层时,以这层的各个节点为根节点的每个子树已经满足了“左子树和右子树是最小堆”的性质。本来应该从最后一个节点开始处理,但如果是叶子节点的话,它自成一个最小堆,所以我们就直接从最后一个非叶子节点开始处理。

我们以数据[13,20,5,3,7,16,24,16]为例,它的完全二叉树的样子如下(n代表没有节点):

              13              
       /              \       
      20               5      
   /      \         /     \   
  24       16      3       7  
 /  \     / \     / \     / \ 
16   n   n   n   n   n   n   n

整个处理过程如下:

开始时整个堆的样子:
              13              
       /              \       
      20               5      
   /      \         /     \   
  24       16      3       7  
 /  \     / \     / \     / \ 
16   n   n   n   n   n   n   n
处理索引为3的根节点的子树,使其变成小顶堆。就是数据为24的那个节点。
可见它下移到了左孩子处,然后到达了叶子节点,不需要继续下移了。
              13              
       /              \       
      20               5      
   /      \         /     \   
  16       16      3       7  
 /  \     / \     / \     / \ 
24   n   n   n   n   n   n   n
处理索引为2的根节点的子树,使其变成小顶堆。就是数据为5的那个节点。
可见它下移到了左孩子处,然后到达了叶子节点,不需要继续下移了。
              13              
       /              \       
      20               3      
   /      \         /     \   
  16       16      5       7  
 /  \     / \     / \     / \ 
24   n   n   n   n   n   n   n
处理索引为1的根节点的子树,使其变成小顶堆。就是数据为20的那个节点。
可见它下移到了左孩子处,之后发现不需要下移了即使还有孩子节点,因为整棵子树已经是最小堆了。
              13              
       /              \       
      16               3      
   /      \         /     \   
  20       16      5       7  
 /  \     / \     / \     / \ 
24   n   n   n   n   n   n   n
处理索引为0的根节点的子树,使其变成小顶堆。就是数据为13的那个节点。
它下移到了右孩子处。
              3               
       /              \       
      16               13     
   /      \         /     \   
  20       16      5       7  
 /  \     / \     / \     / \ 
24   n   n   n   n   n   n   n
它下移到了左孩子处,到达叶子节点。
              3               
       /              \       
      16               5      
   /      \         /     \   
  20       16      13      7  
 /  \     / \     / \     / \ 
24   n   n   n   n   n   n   n

最终结果为:
              3               
       /              \       
      16               5      
   /      \         /     \   
  20       16      13      7  
 /  \     / \     / \     / \ 
24   n   n   n   n   n   n   n
可见整颗树已经符合最小堆的定义了。

整个算法的实现如下,思路如之前介绍的一样。从最后一个非叶子节点开始处理,直到root节点,遍历过程中,对遍历节点为根的每颗子树进行处理,使得变成最小堆。当然,我们利用了“左子树和右子树是最小堆”的性质。

import printHeap

data = [13,20,5,24,16,3,7,16]

length = len(data)
lastNonLeaf = int(len(data)/2) - 1  #最后一个非叶子节点的索引

print("开始时整个堆的样子:")
printHeap.printHeap(data, length)

def exchange(List, i, j):
    temp = List[i]
    List[i] = List[j]
    List[j] = temp

for i in range(lastNonLeaf, -1, -1):  #遍历范围[lastNonLeaf, 0]
    j = i  # 用j暂存每个需要下沉的非叶子节点索引
    print("处理索引为%d的根节点的子树,使其变成小顶堆"%i)
    while(True):  #j保证是合法的,因为更新j时是从合法的k来的,不用写成j < length
        k = 2*j+1   #k用来保存两个孩子较小值的索引
        if k >= length:  #因为是完全二叉树,先检查左孩子是否存在
            break

        #如果左孩子存在,那么去两个孩子的较小值,如果右孩子也存在
        right = 2*j+2
        if right < length and data[k] > data[right]:
            k = right

        #如果当前节点比孩子节点较小值还要小
        if data[j] > data[k]:
            exchange(data, j, k)  #交换当前节点和较小孩子
            j = k                 #当前节点下沉到这个孩子的索引
            printHeap.printHeap(data, length)
        else:
            break

print("\n最终结果为:")
printHeap.printHeap(data, length)

然后把下沉逻辑封装成函数:

data = [13,20,5,3,7,16,24,16]

length = len(data)
lastNonLeaf = int(len(data)/2) - 1

def exchange(List, i, j):
    temp = List[i]
    List[i] = List[j]
    List[j] = temp

def shiftDown(j):
    while(True): 
        k = 2*j+1  
        if k >= length:  
            break

        right = 2*j+2
        if right < length and data[k] > data[right]:
            k = right

        if data[j] > data[k]:
            exchange(data, j, k)  
            j = k                 
        else:
            break

for i in range(lastNonLeaf, -1, -1):
    shiftDown(i)
    
print(data)

最小堆删除顶点

构建完最小堆的数组,其索引0元素就是最小元素了。现在封装一个删除顶点的函数,这样我们每次调用这个函数就能得到数组中的最小值。

具体做法是,临时变量保存data[0]结束前返回,将最后一个叶子节点替换到根节点(这相当于删除了根节点),数组大小减一,然后下沉根节点到合适的位置后,整个堆又变成了最小堆。当然,这里我们又利用了“左子树和右子树是最小堆”的性质,因为刚替换根节点时,它的左右子树已经是最小堆了。

import printHeap

data = [13,20,5,3,7,16,24,16]

length = len(data)
lastNonLeaf = int(len(data)/2) - 1

def exchange(List, i, j):
    temp = List[i]
    List[i] = List[j]
    List[j] = temp

def shiftDown(j):
    while(True): 
        k = 2*j+1  
        if k >= length:  
            break

        right = 2*j+2
        if right < length and data[k] > data[right]:
            k = right

        if data[j] > data[k]:
            exchange(data, j, k)  
            j = k                 
        else:
            break

for i in range(lastNonLeaf, -1, -1):
    shiftDown(i)
    
printHeap.printHeap(data, len(data))
print()

def deletePeak():
    global length, data
    lastIndex = length - 1 
    re = data[0]  #保存顶点值
    data[0] = data[lastIndex]  #将最后一个节点替换到根节点
    shiftDown(0)  #对根节点的这棵树进行下沉
    length -= 1   #大小减一
    data = data[:-1]  #数组去掉最后元素(实际上这句不加也可以,但如果同时有插入删除,这句必须加)
    return re

for i in range(len(data)):
    print("peak is ", deletePeak())
    printHeap.printHeap(data, length)
    print()

打印效果如下:

              3               
       /              \       
      7                5      
   /      \         /     \   
  16       13      16      24 
 /  \     / \     / \     / \ 
20   n   n   n   n   n   n   n

peak is  3
       5         
   /       \     
  7         16   
 /  \      /  \  
16   13   20   24

peak is  5
       7        
   /       \    
  13        16  
 /  \      /  \ 
16   24   20   n

peak is  7
       13      
   /       \   
  16        16 
 /  \      / \ 
20   24   n   n

peak is  13
      16      
   /      \   
  20       16 
 /  \     / \ 
24   n   n   n

peak is  16
  16   
 /  \  
20   24

peak is  16
  20  
 /  \ 
24   n

peak is  20
24

peak is  24

从打印效果可以很方便地看到,每次删除堆顶后,重新构建的新堆都已经是最小堆了。而且,最重要的是,每次删除的堆顶累积起来,刚好构建了升序排序。这种方式我们可以得到排序结果,以记录删除堆顶的方式。(这其实就是堆排序的本质了)

甚至可以得到各种变种算法题的题解:

  • 最小的k个数:for循环执行k次
  • 第k小的数:for循环执行第k次的结果

最小堆添加节点

添加节点时,新节点添加到完全二叉树的下一个叶子节点的位置上去,并且由于整个堆之前已经是最小堆了,所以只可能影响从新节点到根节点的路径上的各个节点,以这些节点为root的子树可能暂时不是最小堆了。我们通过冒泡上移的方式,来使得整个堆恢复最小堆。

import printHeap

data = [13,20,5,3,7,16,24,16]

length = len(data)
lastNonLeaf = int(len(data)/2) - 1

def exchange(List, i, j):
    temp = List[i]
    List[i] = List[j]
    List[j] = temp

def shiftDown(j):
    while(True): 
        k = 2*j+1  
        if k >= length:  
            break

        right = 2*j+2
        if right < length and data[k] > data[right]:
            k = right

        if data[j] > data[k]:
            exchange(data, j, k)  
            j = k                 
        else:
            break

for i in range(lastNonLeaf, -1, -1):
    shiftDown(i)
    
print("开始时整个堆的样子:")
printHeap.printHeap(data, len(data))
print()

def addToHeap(item):
    global length
    data.append(item)
    length += 1
    upIndex = length - 1
    while(upIndex > 0):#为0代表根节点,此时不需要冒泡上移了
        parent = (upIndex - 1)>>1#左右孩子的父节点索引,都可以得到
        if data[upIndex] < data[parent]:
            exchange(data, upIndex, parent)#冒泡上移
            upIndex = parent#更新遍历索引
        else:
            break


for i in range(1,10):
    print("add item is", i)
    addToHeap(i)
    printHeap.printHeap(data, length)
    print()
    
print("\n最终结果为:")
printHeap.printHeap(data, length)

打印效果:

开始时整个堆的样子:
              3               
       /              \       
      7                5      
   /      \         /     \   
  16       13      16      24 
 /  \     / \     / \     / \ 
20   n   n   n   n   n   n   n

add item is 1。数据为1的节点虽然最开始加到了最后一层的第二个节点的位置,
但最终冒泡上移到了根节点的位置,才使得整个堆变成最小堆。
               1               
        /              \       
       3                5      
   /       \         /     \   
  7         13      16      24 
 /  \      / \     / \     / \ 
20   16   n   n   n   n   n   n

add item is 2
                1               
        /               \       
       2                 5      
   /       \          /     \   
  7         3        16      24 
 /  \      /  \     / \     / \ 
20   16   13   n   n   n   n   n

add item is 3
                1               
        /               \       
       2                 5      
   /       \          /     \   
  7         3        16      24 
 /  \      /  \     / \     / \ 
20   16   13   3   n   n   n   n

add item is 4
                1                
        /               \        
       2                 4       
   /       \          /      \   
  7         3        5        24 
 /  \      /  \     /  \     / \ 
20   16   13   3   16   n   n   n

add item is 5
                1                
        /               \        
       2                 4       
   /       \          /      \   
  7         3        5        24 
 /  \      /  \     /  \     / \ 
20   16   13   3   16   5   n   n

add item is 6
                1                 
        /               \         
       2                 4        
   /       \          /      \    
  7         3        5        6   
 /  \      /  \     /  \     /  \ 
20   16   13   3   16   5   24   n

add item is 7
                1                 
        /               \         
       2                 4        
   /       \          /      \    
  7         3        5        6   
 /  \      /  \     /  \     /  \ 
20   16   13   3   16   5   24   7

add item is 8
                              1                               
               /                              \               
              2                                4              
       /              \                 /             \       
      7                3               5               6      
   /      \         /     \         /     \         /     \   
  8        16      13      3       16      5       24      7  
 /  \     / \     / \     / \     / \     / \     / \     / \ 
20   n   n   n   n   n   n   n   n   n   n   n   n   n   n   n

add item is 9
                              1                               
               /                              \               
              2                                4              
       /              \                 /             \       
      7                3               5               6      
   /      \         /     \         /     \         /     \   
  8        16      13      3       16      5       24      7  
 /  \     / \     / \     / \     / \     / \     / \     / \ 
20   9   n   n   n   n   n   n   n   n   n   n   n   n   n   n


最终结果为:
                              1                               
               /                              \               
              2                                4              
       /              \                 /             \       
      7                3               5               6      
   /      \         /     \         /     \         /     \   
  8        16      13      3       16      5       24      7  
 /  \     / \     / \     / \     / \     / \     / \     / \ 
20   9   n   n   n   n   n   n   n   n   n   n   n   n   n   n

实时插入删除

把插入和删除结合,写了一个可交互的,可实时观察的程序。

import printHeap

data = [13,20,5,3,7,16,24,16]

length = len(data)
lastNonLeaf = int(len(data)/2) - 1

def exchange(List, i, j):
    temp = List[i]
    List[i] = List[j]
    List[j] = temp

def shiftDown(j):
    while(True): 
        k = 2*j+1  
        if k >= length:  
            break

        right = 2*j+2
        if right < length and data[k] > data[right]:
            k = right

        if data[j] > data[k]:
            exchange(data, j, k)  
            j = k                 
        else:
            break

for i in range(lastNonLeaf, -1, -1):
    shiftDown(i)
    
print("开始时整个堆的样子:")
printHeap.printHeap(data, len(data))
print()

def deletePeak():
    global length, data
    lastIndex = length - 1 
    re = data[0]  #保存顶点值
    data[0] = data[lastIndex]  #将最后一个节点替换到根节点
    shiftDown(0)  #对根节点的这棵树进行下沉
    length -= 1   #大小减一
    data = data[:-1]  #数组去掉最后元素(实际上这句不加也可以,但如果同时有插入删除,这句必须加)
    return re

def addToHeap(item):
    global length
    data.append(item)
    length += 1
    upIndex = length - 1
    while(upIndex > 0):#为0代表根节点,此时不需要冒泡上移了
        parent = (upIndex - 1)>>1#左右孩子的父节点索引,都可以得到
        if data[upIndex] < data[parent]:
            exchange(data, upIndex, parent)#冒泡上移
            upIndex = parent#更新遍历索引
        else:
            break


while(True):
    want = input("如果想删除请输入d,想增加请输入a空格和你想增加的数字\n")
    if want == "d":
        print("删除顶点值为", deletePeak())
        printHeap.printHeap(data, length)
        print()
    elif want[0] == "a" and len(want.split(" ")) == 2 and want.split(" ")[1].isdigit():
        print("添加节点值为", want.split(" ")[1] )
        addToHeap(int(want.split(" ")[1]))
        printHeap.printHeap(data, length)
        print()
    else:
        print("请按照规则输入")

打印效果如下:

开始时整个堆的样子:
              3               
       /              \       
      7                5      
   /      \         /     \   
  16       13      16      24 
 /  \     / \     / \     / \ 
20   n   n   n   n   n   n   n

如果想删除请输入d,想增加请输入a空格和你想增加的数字
d
删除顶点值为 3
       5         
   /       \     
  7         16   
 /  \      /  \  
16   13   20   24

如果想删除请输入d,想增加请输入a空格和你想增加的数字
a 1
添加节点值为 1
              1               
       /              \       
      5                16     
   /      \         /     \   
  7        13      20      24 
 /  \     / \     / \     / \ 
16   n   n   n   n   n   n   n

自己玩去吧。

堆排序

其实前面的最小堆删除顶点章节已经提及了,这里再次明确下堆排序的步骤:

  1. 已知所有节点,原地构建最小堆
  2. 执行删除顶点动作,重新构建最小堆。
  3. 执行删除顶点动作时,如果不执行data = data[:-1]的话,后面会有一个空位,把删除顶点值放到空位上。
  4. 不停执行2、3步骤,直到所有节点都被删除了一次。

下面示例就解释了,删除顶点动作时,后面会有一个空位。

       5         
   /       \     
  7         16   
 /  \      /  \  
16   13   20   24

peak is  5。删除数据为5的节点后,后面留有一个空位。
       7        
   /       \    
  13        16  
 /  \      /  \ 
16   24   20   n

前面已经说过,累积删除最小堆的顶点值,可以得到一个升序的排序。但前提是,你得新开辟一个同样大小的数组,然后每删除一个节点就依次放置(从新数组的0索引开始)。

但为了复用空间,每次删除我们将删除的顶点值放到空位上去,这样:

  1. 第一次放了个最小值,放到了整个二叉树的最后一个节点上,也就是数组的最后位置。
  2. 第二次放了个次小值,放到了整个二叉树的倒数第二节点上,也就是数组的倒数第二个位置。
  3. 不断重复…

但注意,由于最开始构建的是最小堆,最后整个数组会变成降序排序。

如果你想要最终整个数组是升序排序,则最开始需要构建最大堆。

最小堆删除顶点章节稍加改动即可。

import printHeap

data = [13,20,5,3,7,16,24,16]

length = len(data)
lastNonLeaf = int(len(data)/2) - 1

def exchange(List, i, j):
    temp = List[i]
    List[i] = List[j]
    List[j] = temp

def shiftDown(j):
    while(True): 
        k = 2*j+1  
        if k >= length:  
            break

        right = 2*j+2
        if right < length and data[k] > data[right]:
            k = right

        if data[j] > data[k]:
            exchange(data, j, k)  
            j = k                 
        else:
            break

for i in range(lastNonLeaf, -1, -1):
    shiftDown(i)
    
printHeap.printHeap(data, len(data))
print()

def deletePeak():
    global length, data
    lastIndex = length - 1 
    re = data[0]  
    data[0] = data[lastIndex] 
    shiftDown(0)  
    length -= 1  
    #data = data[:-1] 这句不能使用,因为要留有空位
    data[lastIndex] = re  #将返回值放到空位上去
    return re

for i in range(len(data)):
    print("peak is ", deletePeak())

print(data)

打印结果如下:

              3               
       /              \       
      7                5      
   /      \         /     \   
  16       13      16      24 
 /  \     / \     / \     / \ 
20   n   n   n   n   n   n   n

peak is  3
peak is  5
peak is  7
peak is  13
peak is  16
peak is  16
peak is  20
peak is  24
[24, 20, 16, 16, 13, 7, 5, 3]

总结

重点在于理解最小堆的构建原理:对于每一颗子树,如果它的根节点的左子树和右子树都已经是最小堆了,那么只需要将根节点冒泡下移到合适的位置(或者根本不需要操作,因为它已经是最小堆),就可以使得该子树成为一个最小堆。

在堆排序中的过程中,并不是直接得到整个排序结果,而是每执行一次删除顶点动作,才得到一个最小值或次小值。

你可能感兴趣的:(数据结构与算法,最小堆,堆排序,小顶推)