想用红黑树，却怎么搜都搜不到现成的Python实现，干脆自己写一个。
算法结构按照Sedgewick《算法（第4版）》第三章（3.3节）的左倾红黑树写成，略有改动。
实现了完整的API，也跑了一些测试用例，暂时没发现问题。
这玩意就是好用，谁用谁知道。
废话不多说，直接上代码。
#注意事项：若要自定义 append() 在主键已存在时合并新旧值的方式，请重写（override）RBT.Node.reduce(self, new_val)。默认实现是 self.val.extend(new_val)，因此默认要求该主键对应的值是 list（或支持 extend() 的对象）。
#!/usr/bin/env python
#coding: gbk

########################################################################
#Author: Feng Ruohang
#Create: 2014/10/06 11:38
#Digest: Provide a common data struct: Red Black Tree
########################################################################
"""Left-leaning red-black BST following Sedgewick & Wayne, *Algorithms*
4th ed., section 3.3 (reference implementation: RedBlackBST.java).

The module runs unchanged on Python 2 and Python 3: the Python 2 builtin
``cmp`` is replaced by ``_cmp`` and every ``print`` takes one argument.
"""

_RED = True
_BLACK = False


def _cmp(a, b):
    """Three-way compare: return -1, 0 or 1 as a <, ==, > b.

    Stands in for the Python 2 builtin ``cmp``, which Python 3 removed.
    """
    return (a > b) - (a < b)


class RBT(object):
    """Ordered symbol table (key -> value) backed by a left-leaning
    red-black BST.

    Keys must be mutually comparable with ``<`` and ``>``.  Insertions and
    deletions rebalance the tree so that every root-to-leaf path crosses
    the same number of black links and no red link leans right.

    Each node stores a boolean colour: True (red) or False (black).
    """

    class Node(object):
        '''
        Node used in RBTree.

        The tree builds nodes through ``self.Node``, so an RBT subclass may
        plug in a Node subclass that overrides ``reduce`` to customise how
        ``RBT.append`` merges values for an existing key.
        '''

        def __init__(self, key, value=None, color=False, N=1):
            self.key = key
            self.val = value
            self.color = color  # False for Black, True for Red
            self.N = N          # Total number of nodes in this subtree
            self.left = None
            self.right = None

        # Nodes compare by key only.
        def __cmp__(self, other):   # consulted by Python 2 only
            return _cmp(self.key, other.key)

        def __eq__(self, other):
            return self.key == other.key

        def __ne__(self, other):    # Python 2 does not derive != from ==
            return not self == other

        def __lt__(self, other):    # ordering hook used by Python 3
            return self.key < other.key

        def __add__(self, other):
            """Return ``self.val + other.val``.

            The original read a nonexistent ``value`` attribute and dropped
            the result, so ``node + node`` always failed.
            """
            return self.val + other.val

        def reduce(self, new_val):
            """Merge new_val into this node's value (called by RBT.append).

            The default extends the stored value in place, so it expects
            ``self.val`` to be a list (or anything with ``extend``).
            """
            self.val.extend(new_val)

    def __init__(self):
        self.root = None

    #====================APIs====================#
    #=====Basic API

    def get(self, key):
        """Return the value stored under key, or None if key is absent.

        A stored value of None looks the same as a missing key here; use
        ``contains`` to tell them apart.
        """
        x = self.__get_node(key)
        return None if x is None else x.val

    def put(self, key, val):
        """Insert key -> val, overwriting the value if key already exists."""
        self.root = self.__insert(self.root, key, val, False)
        self.root.color = _BLACK

    def append(self, key, val):
        """Insert key -> val; if key already exists, merge val into the
        stored value via ``Node.reduce`` (``list.extend`` by default).
        """
        self.root = self.__insert(self.root, key, val, True)
        self.root.color = _BLACK

    def delete(self, key):
        """Remove key and its value.

        Raises:
            LookupError: if key is not in the tree.
        """
        if not self.contains(key):
            raise LookupError('No such keys in rbtree. Fail to Delete')
        # Top-down deletion pushes a red link down the search path; start
        # by colouring the root red when both of its children are black.
        if self.__is_black(self.root.left) and self.__is_black(self.root.right):
            self.root.color = _RED
        self.root = self.__delete(self.root, key)
        if not self.is_empty():
            self.root.color = _BLACK

    def del_min(self):
        """Remove the smallest key.

        Raises:
            LookupError: if the tree is empty.
        """
        if self.is_empty():
            raise LookupError('Empty Red-Black Tree. Can\'t delete min')
        if self.__is_black(self.root.left) and self.__is_black(self.root.right):
            self.root.color = _RED
        self.root = self.__del_min(self.root)
        if not self.is_empty():
            self.root.color = _BLACK

    def del_max(self):
        """Remove the largest key.

        Raises:
            LookupError: if the tree is empty.
        """
        if self.is_empty():
            raise LookupError('Empty Red-Black Tree. Can\'t delete max')
        if self.__is_black(self.root.left) and self.__is_black(self.root.right):
            self.root.color = _RED
        self.root = self.__del_max(self.root)
        if not self.is_empty():
            self.root.color = _BLACK

    def size(self):
        """Return the number of keys in the tree."""
        return self.__size(self.root)

    def is_empty(self):
        """Return True if the tree holds no keys."""
        return self.root is None

    def contains(self, key):
        """Return True if key is in the tree.

        Checks for the node itself rather than the truthiness of its value,
        so keys stored with [], 0, None, ... are still reported present.
        """
        return self.__get_node(key) is not None

    #=====Advance API
    def min(self):
        """Return the smallest key, or None if the tree is empty."""
        if self.is_empty():
            return None
        return self.__min(self.root).key

    def max(self):
        """Return the largest key, or None if the tree is empty."""
        if self.is_empty():
            return None
        return self.__max(self.root).key

    def floor(self, key):
        """Return (key, val) for the largest key <= key, or (None, None)."""
        x = self.__floor(self.root, key)
        if x is None:
            return None, None
        return x.key, x.val

    def ceil(self, key):
        """Return (key, val) for the smallest key >= key, or (None, None)."""
        x = self.__ceil(self.root, key)
        if x is None:
            return None, None
        return x.key, x.val

    def below(self, key):
        """Return (key, val) for the largest key strictly < key, or
        (None, None) if there is none.
        """
        rank = self.index(key)  # number of keys < key
        if rank == 0:
            return None, None
        x = self.__select(self.root, rank - 1)
        return x.key, x.val

    def above(self, key):
        """Return (key, val) for the smallest key strictly > key, or
        (None, None) if there is none.
        """
        rank = self.index(key)  # number of keys < key
        if self.contains(key):
            rank += 1           # skip key itself
        if rank >= self.size():
            return None, None   # out of range
        x = self.__select(self.root, rank)
        return x.key, x.val

    def index(self, key):
        """Return the number of keys strictly less than key.

        key does not need to be present in the tree.
        """
        return self.__index(self.root, key)

    def keys(self):
        '''Return all keys in the tree in ascending order.'''
        return self.range(self.min(), self.max())

    def range(self, lo, hi):
        '''Return the keys k with lo <= k <= hi, in ascending order.'''
        q = []
        self.__range(self.root, q, lo, hi)
        return q

    def select(self, index):
        '''Return the key of 0-based rank index, or None if out of range.'''
        if not 0 <= index < self.size():
            return None
        return self.__select(self.root, index).key

    def width(self, lo, hi):
        '''Return the number of keys k with lo <= k <= hi.'''
        if lo > hi:
            return 0
        if self.contains(hi):
            return self.index(hi) - self.index(lo) + 1
        return self.index(hi) - self.index(lo)

    #===============Private Method===============#
    #=====Basic
    def __get_node(self, key):
        """Return the node holding key, or None (iterative search)."""
        x = self.root
        while x is not None:
            tag = _cmp(key, x.key)
            if tag < 0:
                x = x.left
            elif tag > 0:
                x = x.right
            else:
                return x
        return None

    def __insert(self, h, key, val, merge):
        """Insert key into the subtree rooted at h; return its new root.

        For an existing key the value is replaced (merge=False) or merged
        with ``h.reduce(val)`` (merge=True).  Shared by put() and append().
        """
        if h is None:
            return self.Node(key, val, _RED, 1)
        tag = _cmp(key, h.key)
        if tag < 0:
            h.left = self.__insert(h.left, key, val, merge)
        elif tag > 0:
            h.right = self.__insert(h.right, key, val, merge)
        elif merge:
            h.reduce(val)
        else:
            h.val = val  # Update

        # Restore the left-leaning invariants on the way back up.
        if self.__is_black(h.left) and self.__is_red(h.right):
            h = self.__rotate_left(h)
        if self.__is_red(h.left) and self.__is_red(h.left.left):
            h = self.__rotate_right(h)
        if self.__is_red(h.left) and self.__is_red(h.right):
            self.__flip_colors(h)
        h.N = self.__size(h.left) + self.__size(h.right) + 1
        return h

    def __del_min(self, h):
        """Delete the minimum of the subtree rooted at h; return its new root.

        Relies on the caller having made h or h.left red.
        """
        if h.left is None:
            # In an LLRB tree a node without a left child has no right
            # child either, so h itself is the minimum.
            return None
        if self.__is_black(h.left) and self.__is_black(h.left.left):
            # Keep the returned root: move_red_left may rotate, and dropping
            # the result (as the original did) detaches part of the tree.
            h = self.__move_red_left(h)
        h.left = self.__del_min(h.left)
        return self.__balance(h)

    def __del_max(self, h):
        """Delete the maximum of the subtree rooted at h; return its new root."""
        if self.__is_red(h.left):
            # Lean the red link right so the maximum can be taken out of it.
            h = self.__rotate_right(h)
        if h.right is None:
            return None
        if self.__is_black(h.right) and self.__is_black(h.right.left):
            h = self.__move_red_right(h)
        h.right = self.__del_max(h.right)
        return self.__balance(h)

    def __delete(self, h, key):
        """Delete key (which must be present) from the subtree rooted at h;
        return the subtree's new root.
        """
        if _cmp(key, h.key) < 0:
            if self.__is_black(h.left) and self.__is_black(h.left.left):
                h = self.__move_red_left(h)
            h.left = self.__delete(h.left, key)
        else:
            if self.__is_red(h.left):
                h = self.__rotate_right(h)
            if _cmp(key, h.key) == 0 and h.right is None:
                return None
            if self.__is_black(h.right) and self.__is_black(h.right.left):
                h = self.__move_red_right(h)
            if _cmp(key, h.key) == 0:
                # Replace h's entry with its successor, then delete that.
                x = self.__min(h.right)
                h.key = x.key
                h.val = x.val
                h.right = self.__del_min(h.right)
            else:
                h.right = self.__delete(h.right, key)
        return self.__balance(h)

    #=====Advance
    def __min(self, h):
        """Return the leftmost node of the non-empty subtree rooted at h."""
        while h.left is not None:
            h = h.left
        return h

    def __max(self, h):
        """Return the rightmost node of the non-empty subtree rooted at h."""
        while h.right is not None:
            h = h.right
        return h

    def __floor(self, h, key):
        '''Find the NODE with the largest key <= given key under h.'''
        if h is None:
            return None
        tag = _cmp(key, h.key)
        if tag == 0:
            return h
        if tag < 0:
            return self.__floor(h.left, key)
        t = self.__floor(h.right, key)
        return h if t is None else t  # nothing <= key on the right: h wins

    def __ceil(self, h, key):
        '''Find the NODE with the smallest key >= given key under h.'''
        if h is None:
            return None
        tag = _cmp(key, h.key)
        if tag == 0:
            return h
        if tag > 0:
            return self.__ceil(h.right, key)
        t = self.__ceil(h.left, key)
        return h if t is None else t  # nothing >= key on the left: h wins

    def __index(self, h, key):
        '''Return the number of keys < key in the subtree rooted at h.'''
        if h is None:
            return 0
        tag = _cmp(key, h.key)
        if tag < 0:
            return self.__index(h.left, key)
        if tag > 0:
            return self.__size(h.left) + 1 + self.__index(h.right, key)
        return self.__size(h.left)

    def __select(self, h, index):
        '''Return the node of rank index; requires 0 <= index < size(h).'''
        l_size = self.__size(h.left)
        if index < l_size:
            return self.__select(h.left, index)
        if index > l_size:
            return self.__select(h.right, index - l_size - 1)
        return h

    def __range(self, h, q, lo, hi):
        '''Append to q, in order, the keys k under h with lo <= k <= hi.'''
        if h is None:
            return
        tag_lo = _cmp(lo, h.key)
        tag_hi = _cmp(hi, h.key)
        if tag_lo < 0:  # lo is below h.key: the left subtree may hold keys
            self.__range(h.left, q, lo, hi)
        if tag_lo <= 0 and tag_hi >= 0:
            q.append(h.key)
        if tag_hi > 0:  # hi is above h.key: the right subtree may hold keys
            self.__range(h.right, q, lo, hi)

    #===============Adjust Functions=============#
    def __rotate_right(self, h):
        """Rotate h's left child up; return it as the new subtree root."""
        x = h.left
        h.left, x.right = x.right, h
        x.color, x.N = h.color, h.N
        h.color, h.N = _RED, self.__size(h.left) + self.__size(h.right) + 1
        return x

    def __rotate_left(self, h):
        """Rotate h's right child up; return it as the new subtree root."""
        x = h.right
        h.right, x.left = x.left, h
        x.color, x.N = h.color, h.N
        h.color, h.N = _RED, self.__size(h.left) + self.__size(h.right) + 1
        return x

    def __flip_colors(self, h):
        """Toggle the colours of h and both of its (non-empty) children."""
        h.color = not h.color
        h.left.color = not h.left.color
        h.right.color = not h.right.color

    def __move_red_left(self, h):
        """Given red h with black h.left and h.left.left, make h.left or one
        of its children red; return the new subtree root.
        """
        self.__flip_colors(h)
        if self.__is_red(h.right.left):
            # Borrow from the right sibling.  The original skipped the inner
            # rotation and the final flip, breaking the LLRB invariants.
            h.right = self.__rotate_right(h.right)
            h = self.__rotate_left(h)
            self.__flip_colors(h)
        return h

    def __move_red_right(self, h):
        """Given red h with black h.right and h.right.left, make h.right or
        one of its children red; return the new subtree root.
        """
        self.__flip_colors(h)
        if self.__is_red(h.left.left):
            h = self.__rotate_right(h)
            self.__flip_colors(h)  # missing in the original
        return h

    def __balance(self, h):
        """Restore the LLRB invariants at h after a deletion step; return the
        new subtree root with its size field refreshed.
        """
        if self.__is_red(h.right):
            h = self.__rotate_left(h)
        if self.__is_red(h.left) and self.__is_red(h.left.left):
            h = self.__rotate_right(h)
        if self.__is_red(h.left) and self.__is_red(h.right):
            self.__flip_colors(h)
        h.N = self.__size(h.left) + self.__size(h.right) + 1
        return h

    #Class Method
    @staticmethod
    def __is_red(x):
        """Return True if x is a red node (an empty link counts as black)."""
        return x is not None and x.color

    @staticmethod
    def __is_black(x):
        """Return True if x is empty or a black node."""
        return x is None or not x.color

    @staticmethod
    def __size(x):
        """Return the node count of the subtree rooted at x (0 if empty)."""
        return 0 if x is None else x.N


def RBT_testing():
    '''API Examples '''
    t = RBT()
    test_data = "SEARCHXMPL"

    print('=====testing is_empty()\nBefore Insertion')
    print(t.is_empty())

    for letter in test_data:
        t.put(letter, [ord(letter)])
        print("Test Inserting:%s, tree size is %d" % (letter, t.size()))
    print("After insertion it return:")
    print(t.is_empty())
    print("====test is_empty complete\n")

    print("=====Tesing Get method:")
    print("get 's' is ")
    print(t.get('S'))
    print("get 'H' is ")
    print(t.get('H'))

    print('==Trying get null key: get "F" is')
    print(t.get('F'))
    print("=====Testing Get method end\n\n")

    print("=====Testing ceil and floor")
    print("Ceil('L')")
    print(t.ceil('L'))
    print("Ceil('F') *F is not in tree")
    print(t.ceil('F'))

    # The original printed ceil() under the "Floor" headings.
    print("Floor('L')")
    print(t.floor('L'))
    print("Floor('F')")
    print(t.floor('F'))

    print('======test append method')
    print('Orient key e is correspond with')
    print(t.get('E'))
    t.append('E', [4])
    print('==After append')
    print(t.get('E'))
    print("=====Testing Append method end\n\n")

    print("=====Testing index()")
    print("index(E)")
    print(t.index('E'))
    print("index(L),select(4)")
    print("%s %s" % (t.index('L'), t.select(4)))
    print("index('M'),select(5)")
    print("%s %s" % (t.index('M'), t.select(5)))
    print("index a key not in tree:\n index('N'),select(6)")
    print("%s %s" % (t.index('N'), t.select(6)))
    print("index('P')")
    print(t.index('P'))

    print("=====Testing select")
    print("select(3) = ")
    print(t.select(3))
    print("select and index end...\n\n")

    print("====Tesing Min and Max")
    print("min key is:")
    print(t.min())
    print("max key is")
    print(t.max())

    print("==How much between min and max:")
    print(t.width(t.min(), t.max()))
    print("keys between min and max:")
    print(t.keys())
    print("keys in 'E' and 'M' ")
    print(t.range('E', 'M'))

    print("try to delete min_key:")
    print("But we could try contains('A') first")
    print(t.contains('A'))
    t.del_min()
    print("After deletion t.contains('A') is ")
    print(t.contains('A'))

    print(t.min())
    print("try to kill one more min key:")
    t.del_min()
    print(t.min())
    print("try to delete max_key,New Max key is :")
    t.del_max()
    print(t.max())
    print("=====Tesing Min and Max complete\n\n")

    print('=====Deleting Test')
    print(t.size())
    t.delete('H')
    print(t.size())

    print('Delete a non-exists key:')
    try:
        t.delete('F')
    except LookupError:  # was a bare except
        print("*Look up error occur*")

    print("=====Testing Delete method complete")


def test_basic_api():
    """Demo of put/get/append/delete on the keys A-N without H."""
    print("==========Testing Basic API==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM")
    test_data = "FENGDIJKABCLM"  # from A-N, without H

    #=====put()
    print("==========put() test begin!==========")
    for letter in test_data:
        t.put(letter, [ord(letter)])  # value is [ASCII code of letter]
        print("put(%s); Now tree size is %d" % (letter, t.size()))
    print('Final tree size is %d' % t.size())
    print("==========put() test complete!==========\n")

    #=====get()
    print("==========get() test begin!==========")
    print("get('F'):\t%s" % repr(t.get('F')))
    print("get('A'):\t%s" % repr(t.get('A')))
    print("get a non-exist key Z: get('Z'):\t%s" % repr(t.get('Z')))
    print("==========get() test complete!==========\n")

    #=====append()
    print("=====append() test begin!==========")
    print("First append to a exist key:[F]")
    print("Before Append:get('F'):\t%s" % repr(t.get('F')))
    print("append('F',[3,'haha']):\t%s" % repr(t.append('F', [3, 'haha'])))
    print("After Append:get('F'):\t%s\n" % repr(t.get('F')))
    print("Second append to a non-exist key:[O]")
    print("Before Append:get('O'):\t%s" % repr(t.get('O')))
    print("append a non-exist key O: append('O',['value of O']):\t%s"
          % repr(t.append('O', ['value of O'])))
    print("After Append:get('O'):\t%s\n" % repr(t.get('O')))
    print("==========append() test complete!==========\n")

    #=====delete()
    print("==========delete() test begin!==========")
    for letter in reversed(test_data):
        t.delete(letter)
        print("delete(%s); Now tree size is %d" % (letter, t.size()))
    print('Final tree size is %d' % t.size())
    print("==========delete() test complete!==========\n")

    print("==========Basic API Test Complete==========\n\n")


def test_advance_api():
    """Demo of min/max and del_min/del_max."""
    print("==========Testing min max floor ceil above below ==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM")
    test_data = "FENGDIJKABCLM"  # from A-N, without H
    for letter in test_data:
        t.put(letter, [ord(letter)])  # value is [ASCII code of letter]

    #=====min() and del_min()
    print("==========min() and del_min() test begin!==========")
    print("Original min():\t%s" % repr(t.min()))
    print("run del_min()")
    t.del_min()
    print("After del_min:min()\t%s" % repr(t.min()))

    print("run del_min() again")
    t.del_min()
    print("After del_min run again:min()\t%s" % repr(t.min()))

    print("=====max() and del_max() test begin!")
    print("Original max():\t%s" % repr(t.max()))
    print("run del_max()")
    t.del_max()
    print("After del_max:max()\t%s" % repr(t.max()))

    print("run del_max() again")
    t.del_max()
    print("After del_max run again:max()\t%s" % repr(t.max()))
    print("==========min() max() del_min() del_max() test complete!==========\n")


def test_int_api():
    """Print ceil/floor/above/below for keys inside and outside the tree."""
    #======ceil floor above below
    print("==========Testing ceil floor above below ==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM - [AHN] = FEGDIJKBCLM")
    test_data = "FEGDIJKBCLM"  # from A-N, without A, H, N

    for letter in test_data:
        t.put(letter, [ord(letter)])  # value is [ASCII code of letter]
    print("Node\tceil\t\tfloor\t\tabove\t\tbelow")
    for P in ['A', 'B', 'C', 'G', 'H', 'I', 'L', 'M', 'N']:
        print("%s\t%s\t%s\t%s\t%s"
              % (P, t.ceil(P), t.floor(P), t.above(P), t.below(P)))


if __name__ == '__main__':
    test_basic_api()
    test_advance_api()
    test_int_api()
查找所用的数据结构是一步步演化出来的：从链表（顺序查找）到二叉查找树，再到2-3树，最后才有了红黑树。
红黑树本质上是用二叉查找树的形式来表示2-3树：把一个3-结点拆成由一条左倾红链接相连的两个2-结点。
《算法导论》也好，其他各种算法书和博客也罢，讲红黑树都没讲到本质。
Sedgewick的《算法（第4版）》就很不错：起码它告诉了你红黑树是怎么来的。
仔细理解2-3树与红黑树之间的一一对应关系，才能对红黑树那些繁琐的插入、删除调整操作有直观的认识。