JUC--ConcurrentSkipListMap源码分析(基于JDK1.8)

1 概述

我们知道JDK针对键值对的存储提供了HashMap、TreeMap、HashTable和ConcurrentHashMap等工具。这些工具在多线程的情况下,使存入的数据有序同时方便数据的遍历,就显得不那么理想了,所以JDK又为我们提供了一个工具ConcurrentSkipListMap。这个工具能够保证线程安全,同时保证插入数据使有序的(根据key来排序)。下面我们就开始来分析一下ConcurrentSkipListMap的实现原理。

2 跳表数据结构简介

HashMap是基于Hash表实现的,针对插入、删除和定位元素的效率较高。如果我们要根据自然顺序或者自定义顺序来遍历元素(key),这个时候我们就可以使用TreeMap了,因为TreeMap是基于红黑树来实现。如果要保证线程安全呢?那就可以使用HashTable,但是HashTable的效率比较低,所以这个时候就该我们的ConcurrentHashMap上场了,然后ConcurrentHashMap没法按照key来排序,所以要满足上面所有的条件,就可以选择ConcurrentSkipListMap。

ConcurrentSkipListMap是基于跳表来实现的。

传统意义的单链表是一个线性结构,向有序的链表中插入一个节点需要O(n)的时间,查找操作需要O(n)的时间。

跳跃表的简单示例:

如果我们使用上图所示的跳跃表,就可以减少查找所需时间为O(n/2),因为我们可以先通过每个节点的最上面的指针先进行查找,这样子就能跳过一半的节点。

比如我们想查找19,首先和6比较,大于6之后,在和9进行比较,然后在和12进行比较......最后比较到21的时候,发现21大于19,说明查找的点在17和21之间,从这个过程中,我们可以看出,查找的时候跳过了3、7、12等点,因此查找的复杂度为O(n/2)。

查找的过程如下图:

其实,上面基本上就是跳跃表的思想,每一个结点不单单只包含指向下一个结点的指针,可能包含很多个指向后续结点的指针,这样就可以跳过一些不必要的结点,从而加快查找、删除等操作。对于一个链表内每一个结点包含多少个指向后续元素的指针,后续节点个数是通过一个随机函数生成器得到,这样子就构成了一个跳跃表。
随机生成的跳跃表可能如下图所示:

跳跃表其实是一种通过空间换取时间的算法,通过在每个节点上增加向前的指针来提高查询效率。目前开源软件redis就使用了跳表这种数据结构。

3 源码分析

3.1 内部类

针对源码的分析,这里我们首先来看一下内部类,因为这里面定义了ConcurrentSkipListMap存储节点的数据结构。

(1)Node

static final class Node {
        final K key;
        volatile Object value;
        volatile Node next;
... ... 

Node是用于存储每个节点数据的类,里面包含了键值对的数据和指向下一个节点的引用。

(2)Index

static class Index {
        final Node node;
        final Index down;
        volatile Index right;
... ...

Index是用于存储跳跃表数据结构中的索引的类,从这里我们可以看出针对同一个Node,可能有多个Index。

Index中的node就是当前索引包含的节点,down指向的是相同节点Node的下一个索引,right指向的是右边的索引(即,另一个节点上的索引)。

(3)HeadIndex

static final class HeadIndex extends Index {
        final int level;
        HeadIndex(Node node, Index down, Index right, int level) {
            super(node, down, right);
            this.level = level;
        }
    }

HeadIndex是Index类的子类,增加了level属性,用于记录下跳跃表的索引层级数。

针对内部类的分析就到这里,其余的内容可以自行查看源码。接下来我们来看看部分核心函数的实现。

3.2 函数

(1)get函数

针对get函数,通过查看源码,我们可以知道它直接调用了doGet函数,所以这里我们直接分析doGet函数,并在里面添加注释来说明。

private V doGet(Object key) {
        if (key == null)
            throw new NullPointerException();

        //保存自定义比较器
        Comparator cmp = comparator;

        //无限循环直到:获取到指定key的值,或者key不存在
        outer: for (;;) {

            //获取key最近的节点Node(在要找节点的左边),然后向后遍历链
            for (Node b = findPredecessor(key, cmp), n = b.next;;) {
                Object v; int c;

                //查找key的值不存在
                if (n == null)
                    break outer;
                Node f = n.next;

                //为了保证在读取到next并处理的过程中,next值没有改变
                if (n != b.next)                // inconsistent read
                    break;

                //如果节点n已经删除,则直接删除n节点
                if ((v = n.value) == null) {
                    n.helpDelete(b, f);
                    break;
                }

                //
                if (b.value == null || v == n)
                    break;

                //比较key和节点n的key值,如果相等则获取并返回
                if ((c = cpr(cmp, key, n.key)) == 0) {
                    @SuppressWarnings("unchecked") V vv = (V)v;
                    return vv;
                }

                //没有合适的值
                if (c < 0)
                    break outer;

                //赋值遍历
                b = n;
                n = f;
            }
        }
        return null;
    }

上面的函数,其实就是找到距离要找节点最近的Node,然后开始从Node开始向后遍历。这里有一个比较关键的函数findPredecessor,这里面就体现了跳跃表的有点。

private Node findPredecessor(Object key, Comparator cmp) {
        if (key == null)
            throw new NullPointerException(); 
        for (;;) {

            //从头索引开始遍历
            for (Index q = head, r = q.right, d;;) {
                if (r != null) {
                    Node n = r.node;
                    K k = n.key;
                    if (n.value == null) {

                        //删除r
                        if (!q.unlink(r))
                            break;           // restart
                        r = q.right;         // reread r
                        continue;
                    }

                    //查找的key大于k,继续向右边遍历
                    if (cpr(cmp, key, k) > 0) {
                        q = r;
                        r = r.right;
                        continue;
                    }
                }

                //r == null或者查找的key小于k的情况,获取下级索引开始遍历,如果没有下级索引,就直接返回
                if ((d = q.down) == null)
                    return q.node;
                q = d;
                r = d.right;
            }
        }
    }

上面就是对get函数的分析,接下来我们来看看put函数。

(2)put函数

同样,查看源码,我们很容易发现put函数直接调用了doPut函数,所以这里我们直接分析doPut函数。

private V doPut(K key, V value, boolean onlyIfAbsent) {
        Node z;           
        if (key == null)
            throw new NullPointerException();
		
	//获取比较器
        Comparator cmp = comparator;
        outer: for (;;) {
			
	    //遍历节点(针对findPredecessor函数的分析在前面已经进行)
            for (Node b = findPredecessor(key, cmp), n = b.next;;) {
                if (n != null) {
                    Object v; int c;
                    Node f = n.next;
					
		    //数据读取不一致,break,从新进行
                    if (n != b.next)
                        break;
					
		    //如果节点n已经被删除,则删除节点n
                    if ((v = n.value) == null) {
                        n.helpDelete(b, f);
                        break;
                    }
					
		    //如果节点b已经被删除,则删除节点b
                    if (b.value == null || v == n) 
                        break;
						
		    //查找的key大于节点k,则继续遍历
                    if ((c = cpr(cmp, key, n.key)) > 0) {
                        b = n;
                        n = f;
                        continue;
                    }
					
		    //查找的key值等于节点的key
                    if (c == 0) {
						
			//替换掉key对应的value,并返回就得value
                        if (onlyIfAbsent || n.casValue(v, value)) {
                            @SuppressWarnings("unchecked") V vv = (V)v;
                            return vv;
                        }
                        break; 
                    }
                }
                
		//针对插入键值对的key小于n的key的情况,创建新的节点,将新节点的下一节点指向n,并且b节点的下一节点指向新节点。
                z = new Node(key, value, n);
                if (!b.casNext(n, z))
                    break;
                break outer;
            }
        }

		//以下是随机索引级别操作
		
		//随机生成种子
        int rnd = ThreadLocalRandom.nextSecondarySeed();
        if ((rnd & 0x80000001) == 0) { // test highest and lowest bits
            int level = 1, max;
			
	    //判断从右到左有多少个1,从而设置索引级数
            while (((rnd >>>= 1) & 1) != 0)
                ++level;
            Index idx = null;
			
	    //保存头节点
            HeadIndex h = head;
			
	    //小于跳表的层级
            if (level <= (max = h.level)) {
				
		//为节点z生成对应的index节点,并且赋值index节点的down属性
                for (int i = 1; i <= level; ++i)
                    idx = new Index(z, idx, null);
            }
            else { // try to grow by one level
                level = max + 1; // hold in array and later pick the one to use
				
		//生成index节点的数组,其中idxs[0]不加入使用
                @SuppressWarnings("unchecked")Index[] idxs =
                    (Index[])new Index[level+1];
					
		//从下到上生成index节点,并赋值down属性	
                for (int i = 1; i <= level; ++i)
                    idxs[i] = idx = new Index(z, idx, null);
				
		//无限循环处理头节点(针对level超出头节点level的部分)
                for (;;) {
                    h = head;
                    int oldLevel = h.level;
                    if (level <= oldLevel) // lost race to add level
                        break;
                    HeadIndex newh = h;
                    Node oldbase = h.node;
					
		    //为每一层生成一个新的头节点索引
                    for (int j = oldLevel+1; j <= level; ++j)
                        newh = new HeadIndex(oldbase, newh, idxs[j], j);
					
		    //比较替换调头节点
                    if (casHead(h, newh)) {
                        h = newh;
                        idx = idxs[level = oldLevel];
                        break;
                    }
                }
            }
            
	    //插入Index结点
            splice: for (int insertionLevel = level;;) {
				
		//保存新跳表的层数
                int j = h.level;
                for (Index q = h, r = q.right, t = idx;;) {
                    if (q == null || t == null)
                        break splice;
                    if (r != null) {
                        Node n = r.node;
                        // compare before deletion check avoids needing recheck
                        int c = cpr(cmp, key, n.key);
						
			// 结点的值为空,表示需要删除
                        if (n.value == null) {
							
			    // 删除q的Index结点
                            if (!q.unlink(r))
                                break;
                            r = q.right;
                            continue;
                        }
                        if (c > 0) {
                            q = r;
                            r = r.right;
                            continue;
                        }
                    }

                    if (j == insertionLevel) {
                        if (!q.link(r, t))
                            break; // restart
                        if (t.node.value == null) {
                            findNode(key);
                            break splice;
                        }
                        if (--insertionLevel == 0)
                            break splice;
                    }

                    if (--j >= insertionLevel && j < level)
                        t = t.down;
                    q = q.down;
                    r = q.right;
                }
            }
        }
        return null;
    }

上面代码加上注释就是对put函数的分析,比较简单,这里就不进一步分析了,后面有时间再回过头来完善一下。

可以看出,这里保证线程安全使用了大量的CAS操作。

针对其余函数的分析,大家可以自行查看源码,比较简单。欢迎交流,谢谢。

你可能感兴趣的:(Java高并发)