红黑树的Python实现

想用红黑树,怎么搜都搜不到现成的Python实现。干脆自己写一个。

算法的结构按照Sedgewick的《算法(4th)》一书第三章写成,略有改动。

API已全部实现,并用一些测试用例做过验证,暂时没有发现问题。

这玩意就是好用,谁用谁知道。

废话不多说直接上代码。

#注意事项:在子类中重写(override)RBT.Node.reduce(self,new_val),即可自定义 append() 方法对已存在主键对应值的合并操作;默认实现是调用 list.extend()。

  1 #!/usr/bin/env python

  2 #coding: gbk

  3 

  4 ########################################################################

  5 #Author: Feng Ruohang

  6 #Create: 2014/10/06 11:38

  7 #Digest: Provide a common data struct: Red Black Tree

  8 ########################################################################

  9 

 10 class RBT(object):

 11     class Node(object):

 12         '''

 13         Node used in RBTree

 14         '''

 15         def __init__(self,key,value=None,color=False,N=1):

 16             self.key = key

 17             self.val = value

 18             self.color = color   #False for Black, True for Red

 19             self.N = N           #Total numbers of nodes in this subtree

 20             self.left = None

 21             self.right = None

 22 

 23         def __cmp__(l,r):

 24             return cmp(l.key,r.key)

 25 

 26         def __eq__(l,r):

 27             return True if l.key == r.key else False

 28 

 29         def __add__(l,r):

 30             l.value + r.value

 31 

 32         def reduce(self,new_val):

 33             self.val.extend(new_val)

 34 

 35     def __init__(self):

 36         self.root = None

 37 

 38     #====================APIs====================#

 39     #=====Basic API

 40 

 41     def get(self,key):

 42         return  self.__get(self.root,key)

 43 

 44     def put(self,key,val):

 45         self.root = self.__put(self.root,key,val)

 46         self.root.color = False 

 47 

 48     def append(self,key,val):

 49         self.root = self.__append(self.root,key,val)

 50         self.root.color = False

 51 

 52     def delete(self,key):

 53         if not self.contains(key):

 54             raise LookupError('No such keys in rbtree. Fail to Delete')

 55         if self.__is_black(self.root.left) and self.__is_black(self.root.right):

 56             self.root.color = True

 57         self.root = self.__delete(self.root,key)

 58         if not self.is_empty(): 

 59             self.root.color = False

 60 

 61     def del_min(self):

 62         if self.is_empty(): 

 63             raise LookupError('Empty Red-Black Tree. Can\'t delete min')

 64         if self.__is_black(self.root.left) and self.__is_black(self.root.right):

 65             self.root.color = True

 66         self.root = self.__del_min(self.root)

 67         if not self.is_empty(): self.root.color = False

 68 

 69     def del_max(self):

 70         if self.is_empty():

 71             raise LookupError('Empty Red-Black Tree. Can\'t delete max')

 72         if self.__is_black(self.root.left) and self.__is_black(self.root.right):

 73             self.root.color = True

 74         self.root = self.__del_max(self.root)

 75         if not self.is_empty(): self.root.color = False 

 76 

 77     def size(self):

 78         return self.__size(self.root)

 79 

 80     def is_empty(self):

 81         return not self.root

 82 

 83     def contains(self,key):

 84         return bool(self.get(key))

 85 

 86     #=====Advance API

 87     def min(self):

 88         if self.is_empty():

 89             return None

 90         return self.__min(self.root).key

 91 

 92     def max(self):

 93         if self.is_empty():

 94             return None

 95         return self.__max(self.root).key

 96 

 97     def floor(self,key):

 98         x = self.__floor(self.root,key)

 99         if x:

100             return x.key,x.val

101         else:

102             return None,None

103 

104     def ceil(self,key):

105         x = self.__ceil(self.root,key)

106         if x:

107             return x.key,x.val

108         else:

109             return None,None

110 

111     def below(self,key):

112         index = self.index(key)

113         if not 0 <= index - 1 < self.size():

114             return None,None    #Return None if out of range

115         x = self.__select(self.root,index - 1)

116         return x.key,x.val

117 

118     def above(self,key):

119         index = self.index(key)

120         if self.contains(key):

121             if not 0 <= index + 1 < self.size():

122                 return None,None    #Return None if out of range

123             else:

124                 x = self.__select(self.root,index+1)

125                 return x.key,x.val

126         else:#if key is not in tree. then select(i) is what we need

127             if not 0 <= index < self.size():

128                 return None,None    #Return None if out of range

129             else:

130                 x = self.__select(self.root,index)

131                 return x.key,x.val

132 

133     def index(self,key):

134         return self.__index(self.root,key) 

135 

136     def keys(self):

137         '''Return All Keys in the tree '''

138         return self.range(self.min(),self.max())

139 

140     def range(self,lo,hi):

141         '''Take two keys. return keys between them'''

142         q = []

143         self.__range(self.root,q,lo,hi)

144         return q

145 

146     def select(self,index):

147         '''Given Index Return Corresponding key '''

148         if not 0 <= index < self.size():

149             return None

150         return self.__select(self.root,index).key

151 

152     def width(self,lo,hi):

153         '''Return the numbers of keys between lo and hi '''

154         if lo > hi: 

155             return 0

156         if self.contains(hi):

157             return self.index(hi) - self.index(lo) + 1

158         else:

159             return self.index(hi) - self.index(lo)

160 

161 

162     #===============Private Method===============#

163     #=====Basic

164     def __get(self,x,key):

165         while x:

166             tag = cmp(key,x.key)

167             if tag < 0 : x = x.left

168             elif tag > 0 :x = x.right

169             else: return x.val

170 

171     def __put(self,h,key,val):

172         if not h: 

173             return self.Node(key,val,True,1)

174         tag = cmp(key,h.key)

175         if tag < 0: 

176             h.left = self.__put(h.left,key,val)

177         elif tag > 0: 

178             h.right = self.__put(h.right,key,val)

179         else: 

180             h.val = val   #Update

181 

182         if self.__is_black(h.left) and self.__is_red(h.right):

183             h = self.__rotate_left(h)

184         if self.__is_red(h.left) and self.__is_red(h.left.left):

185             h = self.__rotate_right(h)

186         if self.__is_red(h.left) and self.__is_red(h.right):

187             self.__flip_colors(h)

188         h.N = self.__size(h.left) + self.__size(h.right) + 1

189         return h

190 

191     def __append(self,h,key,val):

192         if not h: 

193             return self.Node(key,val,True,1)

194         tag = cmp(key,h.key)

195         if tag < 0: 

196             h.left = self.__append(h.left,key,val)

197         elif tag > 0: 

198             h.right = self.__append(h.right,key,val)

199         else: 

200             h.reduce(val)   #append.

201 

202         if self.__is_black(h.left) and self.__is_red(h.right):

203             h = self.__rotate_left(h)

204         if self.__is_red(h.left) and self.__is_red(h.left.left):

205             h = self.__rotate_right(h)

206         if self.__is_red(h.left) and self.__is_red(h.right):

207             self.__flip_colors(h)

208         h.N = self.__size(h.left) + self.__size(h.right) + 1

209         return h

210 

211     def __del_min(self,h):

212         if not h.left: #if h is empty:return None

213             return None

214 

215         if self.__is_black(h.left) and self.__is_black(h.left.left):

216             self.__move_red_left(h)

217         h.left = self.__del_min(h.left) #Del recursive

218         return self.__balance(h)

219 

220     def __del_max(self,h):

221         if self.__is_red(h.left): 

222             h = self.__rotate_right(h)

223         if not h.right: 

224             return None

225         if self.__is_black(h.right) and self.__is_black(h.right.left):

226             h = self.__move_red_right(h)

227         h.right = self.__del_max(h.right)

228         return self.__balance(h)

229 

230     def __delete(self,h,key):

231         if key < h.key:

232             if self.__is_black(h.left) and self.__is_black(h.left.left):

233                 h = self.__move_red_left(h)

234             h.left = self.__delete(h.left,key)

235         else:

236             if self.__is_red(h.left):

237                 h = self.__rotate_right(h)

238             if key == h.key and not h.right:

239                 return None

240             if self.__is_black(h.right) and self.__is_black(h.right.left):

241                 h = self.__move_red_right(h)

242             if key == h.key:#replace h with min of right subtree

243                 x = self.__min(h.right)

244                 h.key = x.key

245                 h.val = x.val

246                 h.right = self.__del_min(h.right)

247             else:

248                 h.right = self.__delete(h.right,key)

249         h = self.__balance(h)

250         return h

251     

252     #=====Advance

253     def __min(self,h):

254         #Assume h is not null

255         if not h.left:

256             return h

257         else:

258             return self.__min(h.left)

259 

260     def __max(self,h):

261         #Assume h is not null

262         if not h.right:

263             return h

264         else:

265             return self.__max(h.right)

266 

267     def __floor(self,h,key):

268         '''Find the NODE with key <= given key in the tree rooted at h '''

269         if not h:

270             return None

271         tag = cmp(key,h.key)

272         if tag == 0:

273             return h

274         if tag < 0:

275             return self.__floor(h.left,key)

276         t = self.__floor(h.right,key)

277         if t:#if find in right tree

278             return t

279         else:#else return itself

280             return h

281 

282     def __ceil(self,h,key):

283         '''Find the NODE with key >= given key in the tree rooted at h '''

284         if not h:

285             return None

286         tag = cmp(key,h.key)

287         if tag == 0:

288             return h

289         if tag > 0: # key is bigger

290             return self.__ceil(h.right,key)

291         t = self.__ceil(h.left,key)#key is lower.Try to find ceil left

292         if t:#if find in left tree

293             return t

294         else:#else return itself

295             return h

296 

297     def __index(self,h,key):

298         if not h:

299             return 0

300         tag = cmp(key,h.key)

301         if tag < 0:

302             return self.__index(h.left,key)

303         elif tag > 0:   #Key is bigger

304             return self.__index(h.right,key) + 1 + self.__size(h.left)

305         else:   #Eq

306             return self.__size(h.left)

307 

308     def __select(self,h,index):

309         '''assert h. assert 0 <= index < size(tree) '''

310         l_size = self.__size(h.left)

311         if l_size > index:

312             return self.__select(h.left,index)

313         elif l_size < index:

314             return self.__select(h.right,index - l_size - 1)

315         else:

316             return h

317 

318     def __range(self,h,q,lo,hi):

319         if not h: 

320             return

321         tag_lo = cmp(lo,h.key)

322         tag_hi = cmp(hi,h.key)

323         if tag_lo < 0:#lo key is lower than h.key

324             self.__range(h.left,q,lo,hi)

325         if tag_lo <= 0 and tag_hi >= 0:

326             q.append(h.key)

327         if tag_hi > 0 :# hi key is bigger than h.key

328             self.__range(h.right,q,lo,hi)

329 

330 

331     #===============Adjust Functions=============#

332     def __rotate_right(self,h):

333         x = h.left

334         h.left,x.right = x.right,h

335         x.color,x.N = h.color,h.N

336         h.color,h.N = True,self.__size(h.left) + self.__size(h.right) + 1

337         return x

338 

339     def __rotate_left(self,h):

340         x = h.right

341         h.right,x.left = x.left,h

342         x.color,x.N = h.color,h.N

343         h.color,h.N = True,self.__size(h.left) + self.__size(h.right) + 1

344         return x

345 

346     def __flip_colors(self,h):

347         h.color = not h.color

348         h.left.color = not h.left.color

349         h.right.color = not h.right.color

350 

351     def __move_red_left(self,h):

352         self.__flip_colors(h)

353         if self.__is_red(h.right.left):

354             h = self.__rotate_left(h)

355         return h

356 

357     def __move_red_right(self,h):

358         self.__flip_colors(h)

359         if self.__is_red(h.left.left):

360             h = self.__rotate_right(h)

361         return h

362 

363     def __balance(self,h):

364         if self.__is_red(h.right): 

365             h = self.__rotate_left(h)

366         if self.__is_red(h.left) and self.__is_red(h.left.left): 

367             h = self.__rotate_right(h)

368         if self.__is_red(h.left) and self.__is_red(h.right):

369             self.__flip_colors(h)

370         h.N = self.__size(h.left) + self.__size(h.right) + 1

371         return h

372     

373     #Class Method

374     @staticmethod

375     def __is_red(x):

376         return False if not x else x.color

377 

378     @staticmethod

379     def __is_black(x):

380         return True  if not x else not x.color

381 

382     @staticmethod

383     def __size(x):

384         return 0 if not x else x.N

385 

386 

def RBT_testing():
    """Walk through the RBT API on the keys of "SEARCHXMPL", printing results.

    This is a print-based demo rather than an asserting test: it inserts
    one node per letter (value is ``[ord(letter)]``) and then exercises
    get, ceil/floor, append, index/select, min/max, width, keys, range,
    del_min/del_max and delete.
    """
    t = RBT()
    test_data = "SEARCHXMPL"

    print('=====testing is_empty()\nBefore Insertion')
    print(t.is_empty())

    for letter in test_data:
        t.put(letter, [ord(letter)])
        print("Test Inserting:%s, tree size is %d" % (letter, t.size()))
    print("After insertion it return:")
    print(t.is_empty())
    print("====test is_empty complete\n")

    print("=====Tesing Get method:")
    print("get 's' is ")
    print(t.get('S'))
    print("get 'H' is ")
    print(t.get('H'))

    print('==Trying get null key: get "F" is')
    print(t.get('F'))
    print("=====Testing Get method end\n\n")

    print("=====Testing ceil and floor")
    print("Ceil('L')")
    print(t.ceil('L'))
    print("Ceil('F') *F is not in tree")
    print(t.ceil('F'))

    # The floor section must call floor(); it previously repeated ceil().
    print("Floor('L')")
    print(t.floor('L'))
    print("Floor('F')")
    print(t.floor('F'))

    print('======test append method')
    print('Orient key e is correspond with')
    print(t.get('E'))
    t.append('E', [4])
    print('==After append')
    print(t.get('E'))
    print("=====Testing Append method end\n\n")

    print("=====Testing index()")
    print("index(E)")
    print(t.index('E'))
    print("index(L),select(4)")
    print("%s %s" % (t.index('L'), t.select(4)))
    print("index('M'),select(5)")
    print("%s %s" % (t.index('M'), t.select(5)))
    print("index a key not in tree:\n index('N'),select(6)")
    print("%s %s" % (t.index('N'), t.select(6)))
    print("index('P')")
    print(t.index('P'))

    print("=====Testing select")
    print("select(3) = ")
    print(t.select(3))
    print("select and index end...\n\n")

    print("====Tesing Min and Max")
    print("min key is:")
    print(t.min())
    print("max key is")
    print(t.max())

    print("==How much between min and max:")
    print(t.width(t.min(), t.max()))
    print("keys between min and max:")
    print(t.keys())
    print("keys in 'E' and 'M' ")
    print(t.range('E', 'M'))

    print("try to delete min_key:")
    print("But we could try contains('A') first")
    print(t.contains('A'))
    t.del_min()
    print("After deletion t.contains('A') is ")
    print(t.contains('A'))

    print(t.min())
    print("try to kill one more min key:")
    t.del_min()
    print(t.min())
    print("try to delete max_key,New Max key is :")
    t.del_max()
    print(t.max())
    print("=====Tesing Min and Max complete\n\n")

    print('=====Deleting Test')
    print(t.size())
    t.delete('H')
    print(t.size())

    print('Delete a non-exists key:')
    try:
        t.delete('F')
    except LookupError:
        # Only the documented "missing key" failure is expected here; any
        # other exception is a real bug and should propagate.
        print("*Look up error occur*")

    print("=====Testing Delete method complete")

494 

def test_basic_api():
    """Print-based demo of put(), get(), append() and delete().

    Inserts one node per letter of "FENGDIJKABCLM" (value is
    ``[ord(letter)]``), shows lookups for present and absent keys, appends
    to an existing and to a new key, then deletes the original letters in
    reverse insertion order.
    """
    print("==========Testing Basic API==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM")
    test_data = "FENGDIJKABCLM"  # from A-N,without H

    #=====put()
    print("==========put() test begin!==========")
    for letter in test_data:
        t.put(letter, [ord(letter)])  # Value is [ascii order of letter]
        print("put(%s); Now tree size is %d" % (letter, t.size()))
    print('Final tree size is %d' % t.size())
    print("==========put() test complete!==========\n")

    #=====get()
    print("==========get() test begin!==========")
    print("get('F'):\t%s" % repr(t.get('F')))
    print("get('A'):\t%s" % repr(t.get('A')))
    print("get a non-exist key Z: get('Z'):\t%s" % repr(t.get('Z')))
    print("==========get() test complete!==========\n")

    #=====append()
    print("=====append() test begin!==========")
    print("First append to a exist key:[F]")
    print("Before Append:get('F'):\t%s" % repr(t.get('F')))
    print("append('F',[3,'haha']):\t%s" % repr(t.append('F', [3, 'haha'])))
    print("After Append:get('F'):\t%s\n" % repr(t.get('F')))
    print("Second append to a non-exist key:[O]")
    print("Before Append:get('O'):\t%s" % repr(t.get('O')))
    print("append a non-exist key O: append('O',['value of O']):\t%s" % repr(t.append('O', ['value of O'])))
    print("After Append:get('O'):\t%s\n" % repr(t.get('O')))
    print("==========append() test complete!==========\n")

    #=====delete()
    print("==========delete() test begin!==========")
    # Delete in reverse insertion order; the appended key 'O' stays behind.
    test_data2 = [x for x in test_data]
    test_data2.reverse()
    for letter in test_data2:
        t.delete(letter)
        print("delete(%s); Now tree size is %d" % (letter, t.size()))
    print('Final tree size is %d' % t.size())
    print("==========delete() test complete!==========\n")

    print("==========Basic API Test Complete==========\n\n")

539 

def test_advance_api():
    """Print-based demo of min(), max(), del_min() and del_max().

    Builds a tree from the letters of "FENGDIJKABCLM", then removes the
    minimum twice and the maximum twice, printing the extreme key after
    each step.
    """
    print("==========Testing min max floor ceil above below ==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM")
    test_data = "FENGDIJKABCLM"  # from A-N,without H
    for letter in test_data:
        t.put(letter, [ord(letter)])  # Value is [ascii order of letter]

    #=====min() and del_min()
    print("==========min() and del_min() test begin!==========")
    print("Original min():\t%s" % repr(t.min()))
    print("run del_min()")
    t.del_min()
    print("After del_min:min()\t%s" % repr(t.min()))

    print("run del_min() again")
    t.del_min()
    print("After del_min run again:min()\t%s" % repr(t.min()))

    print("=====max() and del_max() test begin!")
    print("Original max():\t%s" % repr(t.max()))
    print("run del_max()")
    t.del_max()
    print("After del_max:max()\t%s" % repr(t.max()))

    print("run del_max() again")
    t.del_max()
    print("After del_max run again:max()\t%s" % repr(t.max()))
    print("==========min() max() del_min() del_max() test complete!==========\n")

569 

def test_int_api():
    """Print a table of ceil/floor/above/below for a set of probe keys.

    The tree holds the letters of "FEGDIJKBCLM" (A-N without A, H and N),
    so the probes cover keys that are present, absent inside the range and
    absent beyond either end.
    """
    #======ceil floor above below
    print("==========Testing ceil floor above below ==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM - [AHN] = FEGDIJKBCLM")
    test_data = "FEGDIJKBCLM"  # from A-N, Del A H N

    for letter in test_data:
        t.put(letter, [ord(letter)])  # Value is [ascii order of letter]
    print("Node\tceil\t\tfloor\t\tabove\t\tbelow")
    for P in ['A', 'B', 'C', 'G', 'H', 'I', 'L', 'M', 'N']:
        print("%s\t%s\t%s\t%s\t%s" % (P, t.ceil(P), t.floor(P), t.above(P), t.below(P)))

582 

if __name__ == '__main__':
    # Run the print-based demos in order when executed as a script.
    test_basic_api()
    test_advance_api()
    test_int_api()
View Code

用于查找的数据结构是一步步演化到红黑树的:从链表到二叉查找树,再到2-3树,最后到红黑树。

红黑树本质上是用二叉查找树的形式来表示一棵2-3树:一条红链接把两个2-结点连在一起,等价于2-3树中的一个3-结点。

《算法导论》也好,其他什么乱七八糟算法书博客也罢,讲红黑树都没讲到本质。

Sedgewick的《算法(4th)》这本书就很不错:起码他告诉你红黑树是怎么来的。 

 仔细理解2-3树与红黑树的相同之处,才能对那些乱七八糟的插入删除调整操作有直观的认识。

你可能感兴趣的:(python)