跳表 SkipList(Go 语言实现)

以下代码来自网上的一个开源实现:

https://github.com/AceDarkknight/ConcurrentSkipList

数据结构

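我们用 ConcurrentSkipList 这个结构体代表整个 skip list。它内部是一个包含多个 skipList 的切片:整个 uint64 索引空间被划分为 SHARDS(32)个连续区间,每个区间由一个独立的 skipList(分片)负责。由于各分片按区间从小到大排列,依次遍历各分片即可得到全局有序的结果;而每个分片各自持有一把锁,落在不同分片上的操作互不阻塞。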

// ConcurrentSkipList is a sharded skip list: it holds a slice of independent
// skipList shards, each protected by its own lock.
type ConcurrentSkipList struct {
	skipLists []*skipList // one shard per index range, ordered by range
	level     int         // number of levels of every shard
}

skipList 的结构如下。每个 skipList 除了头节点、尾节点(tail 恒为 nil,用来标记每一层的末尾)、层数(level)和长度(length)之外,还有一把读写锁,用于保证并发安全。

// skipList is a single shard: an ordered multi-level linked list guarded by a
// read-write mutex.
type skipList struct {
	level  int          // number of forward pointers in head.nextNodes
	length int32        // node count; changed atomically while mutex is held
	head   *Node        // sentinel node; never holds user data
	tail   *Node        // always nil; marks the end of every level
	mutex  sync.RWMutex // guards the links and values of all nodes
}

其中我们把每个节点称为一个 Node,结构如下:index 是节点的索引值(也是排序依据,同一索引重复插入时只覆盖 value),value 是节点存储的值,nextNodes[i] 记录该节点在第 i 层指向的下一个节点。

// Node is a single element of the skip list.
type Node struct {
	index     uint64      // sort key; inserting an existing index overwrites value
	value     interface{} // user payload
	nextNodes []*Node     // nextNodes[i] is the successor on level i
}

concurrentSkipList.go源文件

/*
Package ConcurrentSkipList provides a skip list keyed by uint64 indexes. The
index space is split into fixed ranges, each served by its own independently
locked skip list (shard), so the list is safe for concurrent use and
operations on different shards do not contend for the same lock.
*/
package ConcurrentSkipList

import (
	"errors"
	"math"
	"sync/atomic"

	"github.com/OneOfOne/xxhash"
)

// MAX_LEVEL and PROBABILITY follow Redis's skip list implementation.
// For the underlying theory see William Pugh's paper:
// ftp://ftp.cs.umd.edu/pub/skipLists/skiplists.pdf
const (
	MAX_LEVEL   = 32   // largest level accepted by NewConcurrentSkipList
	PROBABILITY = 0.25 // chance that randomLevel adds one more level to a new node
	SHARDS      = 32   // number of independent skipList shards
)

// shardIndexes[i] is the inclusive upper bound of the indexes stored in shard i.
// init fills it in ascending order, and the last entry is math.MaxUint64, so
// getShardIndex can pick the first bound that is >= a given index.
var shardIndexes = make([]uint64, SHARDS)

// init fills shardIndexes so that the uint64 index space is split into SHARDS
// contiguous ranges. shardIndexes[i] is the inclusive upper bound of range i;
// the slice is ascending and its last entry is math.MaxUint64.
func init() {
	// step is the width of each range, ceil(2^64 / SHARDS), which is 1<<59 for
	// 32 shards. It is derived from SHARDS instead of being hard-coded, so
	// changing SHARDS cannot silently produce skewed or wrapped boundaries.
	// The +1 is done at run time: for SHARDS == 1 it wraps step to 0, which is
	// harmless because the only boundary is math.MaxUint64.
	var step uint64 = math.MaxUint64 / SHARDS
	step++

	// Walk down from the top so the last range ends exactly at MaxUint64.
	// When SHARDS is not a power of two, shard 0 ends up slightly narrower.
	// The final subtraction (after i == 0) may wrap, but its result is unused.
	var t uint64 = math.MaxUint64
	for i := SHARDS - 1; i >= 0; i-- {
		shardIndexes[i] = t
		t -= step
	}
}

// ConcurrentSkipList is a sharded skip list: it holds a slice of independent
// skipList shards, each protected by its own lock.
type ConcurrentSkipList struct {
	skipLists []*skipList // one shard per index range, ordered by range
	level     int         // number of levels of every shard
}

// NewConcurrentSkipList builds a list made of SHARDS empty shards, each with
// the given number of levels.
//
// level must lie in [1, MAX_LEVEL]; otherwise nil and an error are returned.
// A reasonable choice is L(N) = log_{1/PROBABILITY}(N), where N is the number
// of elements you expect to store. With PROBABILITY = 0.25 and N = 10,000,000
// this gives L(N) ≈ 12. See ftp://ftp.cs.umd.edu/pub/skipLists/skiplists.pdf.
//
// Every shard's head node starts with level forward pointers, all pointing at
// the (nil) tail.
func NewConcurrentSkipList(level int) (*ConcurrentSkipList, error) {
	if level < 1 || level > MAX_LEVEL {
		return nil, errors.New("invalid level, level must between 1 to 32")
	}

	shards := make([]*skipList, SHARDS)
	for i := range shards {
		shards[i] = newSkipList(level)
	}

	return &ConcurrentSkipList{skipLists: shards, level: level}, nil
}

// Level returns the level passed to NewConcurrentSkipList, which is the number
// of levels every shard was created with.
func (s *ConcurrentSkipList) Level() int {
	return s.level
}

// Length returns the total number of nodes across all shards.
//
// Each shard's counter is read atomically, but the shards are not read as one
// unit. Under concurrent writes the total is therefore a best-effort value.
func (s *ConcurrentSkipList) Length() int32 {
	total := int32(0)
	for i := range s.skipLists {
		total += s.skipLists[i].getLength()
	}
	return total
}

// Search looks up index in the shard responsible for it. It returns the
// matching node and true when the index is present, or nil and false
// otherwise. An empty shard is detected with an atomic length check, so its
// lock is not taken in that case.
func (s *ConcurrentSkipList) Search(index uint64) (*Node, bool) {
	shard := s.skipLists[getShardIndex(index)]
	if atomic.LoadInt32(&shard.length) != 0 {
		if node := shard.searchWithoutPreviousNodes(index); node != nil {
			return node, true
		}
	}
	return nil, false
}

// Insert stores value under index. If a node with that index already exists,
// its value is overwritten; otherwise a new node is added. A nil value is
// ignored and nothing is stored.
func (s *ConcurrentSkipList) Insert(index uint64, value interface{}) {
	if value != nil {
		s.skipLists[getShardIndex(index)].insert(index, value)
	}
}

// Delete removes the node with the given index, if present. An empty shard is
// detected with an atomic length check, so its lock is not taken in that case.
func (s *ConcurrentSkipList) Delete(index uint64) {
	shard := s.skipLists[getShardIndex(index)]
	if atomic.LoadInt32(&shard.length) != 0 {
		shard.delete(index)
	}
}

// ForEach calls f on every node, one shard at a time in shard order. Within a
// shard, nodes are visited in ascending index order.
//
// Each non-empty shard is first copied with snapshot(). f therefore receives
// copies that do not reflect inserts or deletes made during iteration, and no
// shard lock is held while f runs. Iteration stops as soon as f returns false.
// Each snapshot is held in memory while it is walked, so this is not cheap on
// large shards.
func (s *ConcurrentSkipList) ForEach(f func(node *Node) bool) {
	for _, shard := range s.skipLists {
		if shard.getLength() == 0 {
			continue
		}
		for _, node := range shard.snapshot() {
			if !f(node) {
				return
			}
		}
	}
}

// Sub returns up to length consecutive nodes, starting at position startNumber
// (0-based, like a slice index) in shard order. The nodes are snapshot copies.
//
// It returns nil when length is not positive, when startNumber is negative, or
// when startNumber is greater than Length(). Shards that lie entirely before
// startNumber are skipped using their length counters, without being copied.
func (s *ConcurrentSkipList) Sub(startNumber int32, length int32) []*Node {
	// Ignore invalid parameters.
	if length <= 0 || startNumber < 0 || startNumber > s.Length() {
		return nil
	}

	var out []*Node
	// passed counts nodes stepped over before startNumber; it never exceeds startNumber.
	var passed int32
	for _, shard := range s.skipLists {
		size := shard.getLength()
		if size == 0 || passed+size <= startNumber {
			// Empty shard, or the whole shard lies before the window.
			passed += size
			continue
		}

		for _, node := range shard.snapshot() {
			if passed < startNumber {
				passed++
				continue
			}
			if int32(len(out)) == length {
				break
			}
			out = append(out, node)
		}

		if int32(len(out)) == length {
			break
		}
	}

	return out
}

// getShardIndex returns the first shard whose inclusive upper bound in
// shardIndexes is >= index. It returns -1 if there is no such shard, which
// cannot happen while the last bound is math.MaxUint64.
func getShardIndex(index uint64) int {
	for shard, upper := range shardIndexes {
		if index <= upper {
			return shard
		}
	}
	return -1
}

// Hash returns the 64-bit xxHash of input, for use as a skip-list index.
// Shards are assigned by index range, so hashed keys are presumably spread
// evenly across shards, unlike small sequential indexes, which all land in
// shard 0.
// See https://cyan4973.github.io/xxHash/ for the algorithm.
func Hash(input []byte) uint64 {
	return xxhash.Checksum64(input)
}

skipList.go源文件

package ConcurrentSkipList

import (
	"math/rand"
	"sync"
	"sync/atomic"
)

// skipList is a single shard: an ordered multi-level linked list guarded by a
// read-write mutex.
type skipList struct {
	level  int          // number of forward pointers in head.nextNodes
	length int32        // node count; changed atomically while mutex is held
	head   *Node        // sentinel node; never holds user data
	tail   *Node        // always nil; marks the end of every level
	mutex  sync.RWMutex // guards the links and values of all nodes
}

// newSkipList returns an empty shard with the given number of levels. Its head
// is a sentinel node (index 0, nil value).
func newSkipList(level int) *skipList {
	// tail is left as its zero value, nil. make inside newNode already fills
	// head.nextNodes with nil, so every level of head points at tail without
	// an explicit loop. length starts at its zero value, 0.
	return &skipList{
		level: level,
		head:  newNode(0, nil, level),
	}
}

// searchWithPreviousNodes walks the shard looking for index. It does not lock.
// Its callers, insert and delete, hold the write lock.
//
// The first result has one entry per level. Entry l is the last node on level
// l whose index is < index, or s.head if there is none. These are the
// predecessors whose pointers change when a node is spliced in or out.
//
// The second result is the first node whose index is >= index. If no such
// node exists, it is the last node with index < index, or s.head when the
// shard is empty. It is never tail (nil), so callers can dereference it.
func (s *skipList) searchWithPreviousNodes(index uint64) ([]*Node, *Node) {
	update := make([]*Node, s.level)

	node := s.head
	for l := s.level - 1; l >= 0; l-- {
		// Advance along level l while the successor is still smaller than index.
		for next := node.nextNodes[l]; next != s.tail && next.index < index; next = node.nextNodes[l] {
			node = next
		}
		update[l] = node
	}

	// Step onto the candidate only if it exists; never return tail.
	if next := node.nextNodes[0]; next != s.tail {
		return update, next
	}
	return update, node
}

// searchWithoutPreviousNodes returns the node whose index equals index, or nil
// if there is none. It takes the read lock itself. Unlike
// searchWithPreviousNodes, it does not record predecessors, so it is the
// cheaper choice for plain lookups.
//
// The returned node is a detached copy (index and value only, no links), made
// while the read lock is held, just as snapshot copies nodes. Returning the
// live node would let callers read its value without any lock while insert
// overwrites that value under the write lock, which is a data race.
func (s *skipList) searchWithoutPreviousNodes(index uint64) *Node {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	current := s.head
	for l := s.level - 1; l >= 0; l-- {
		for current.nextNodes[l] != s.tail && current.nextNodes[l].index < index {
			current = current.nextNodes[l]
		}
	}

	// current is the last node with index < index, so its level-0 successor is
	// the only node that can match.
	candidate := current.nextNodes[0]
	if candidate == s.tail || candidate.index != index {
		return nil
	}

	return &Node{
		index: candidate.index,
		value: candidate.value,
	}
}

// insert stores value under index in this shard while holding the write lock.
// If a node with that index exists, only its value is replaced. Otherwise a
// node with a random height (see randomLevel) is spliced in on each of its
// levels, and the length counter is incremented atomically.
func (s *skipList) insert(index uint64, value interface{}) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	update, found := s.searchWithPreviousNodes(index)
	// found is s.head for an empty shard. The head's index 0 must not be
	// mistaken for a stored node with index 0.
	if found != s.head && found.index == index {
		found.value = value
		return
	}

	node := newNode(index, value, s.randomLevel())
	// On each level, point the new node at the old successor, then redirect
	// the predecessor to the new node. The levels touch disjoint slots, so
	// their order does not matter.
	for l := range node.nextNodes {
		node.nextNodes[l] = update[l].nextNodes[l]
		update[l].nextNodes[l] = node
	}

	atomic.AddInt32(&s.length, 1)
}

// delete removes the node with the given index from this shard while holding
// the write lock, and decrements the length counter atomically. If the index
// is absent, nothing changes.
func (s *skipList) delete(index uint64) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	update, target := s.searchWithPreviousNodes(index)
	// target is s.head for an empty shard. The head's index 0 must not be
	// mistaken for a stored node with index 0.
	if target == s.head || target.index != index {
		return
	}

	// Unlink target on every level it occupies, and clear its forward pointers.
	for l, next := range target.nextNodes {
		update[l].nextNodes[l] = next
		target.nextNodes[l] = nil
	}

	atomic.AddInt32(&s.length, -1)
}

// snapshot returns copies of every node in this shard, in ascending index
// order. Each copy has only index and value (no links). The copies are taken
// under the read lock, so later changes to the shard do not affect them.
func (s *skipList) snapshot() []*Node {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	nodes := make([]*Node, s.length)
	for i, n := 0, s.head.nextNodes[0]; n != s.tail; i, n = i+1, n.nextNodes[0] {
		nodes[i] = &Node{index: n.index, value: n.value}
	}
	return nodes
}

// getLength returns the number of nodes in this shard. It uses an atomic load,
// so it may be called without holding s.mutex.
func (s *skipList) getLength() int32 {
	return atomic.LoadInt32(&s.length)
}

// randomLevel picks the height of a new node. It starts at 1 and adds one
// level each time a random draw is below PROBABILITY, stopping at s.level, so
// the result lies in [1, s.level]. This follows Redis's implementation.
func (s *skipList) randomLevel() int {
	// The draw is evaluated before the cap check, as in the original.
	height := 1
	for rand.Float64() < PROBABILITY && height < s.level {
		height++
	}
	return height
}

node.go源文件

package ConcurrentSkipList

// Node is a single element of the skip list.
type Node struct {
	index     uint64      // sort key; inserting an existing index overwrites value
	value     interface{} // user payload
	nextNodes []*Node     // nextNodes[i] is the successor on level i
}

// newNode allocates a node with the given index and value and room for level
// forward pointers, all initially nil. It is for use inside this package only.
func newNode(index uint64, value interface{}, level int) *Node {
	n := new(Node)
	n.index = index
	n.value = value
	n.nextNodes = make([]*Node, level)
	return n
}

// Index returns the node's index, which is the key the skip list sorts by.
func (n *Node) Index() uint64 {
	return n.index
}

// Value returns the value stored in the node.
func (n *Node) Value() interface{} {
	return n.value
}

 

你可能感兴趣的:(数据结构)