ConcurrentHashMap源码详解

1. ConcurrentHashMap概述

ConcurrentHashMap是线程安全的哈希表,不同于HashTable,后者在方法上增加synchronized关键字,利用对象同步锁实现线程之间的同步。显然,HashTable实现线程安全的方式太“重”,并发度高的情况下,很多线程争用同一把锁,吞吐量较低。

ConcurrentHashMap通过锁分段技术,只有在同一个段内,才会存在锁竞争,提高了并发处理能力。它的内部数据结构其实是一个Segment数组,该数组的大小代表了ConcurrentHashMap的并发度,Segment同时也是一把可重入锁,该锁用来确保该段数据并发访问的线程安全。每一个Segment其实是一个类似于HashMap的哈希表,用来存储key-value。看下ConcurrentHashMap结构图:

ConcurrentHashMap源码详解_第1张图片

ConcurrentHashMap维护了一个Segment数组segments,每个Segment是一个哈希表。当线程需要访问segments[1]处的哈希表,首先需要获取该段的锁,然后才能访问该段的哈希表。上图中segments数组大小为8,因此并发度为8,最多支持8个线程在不同的段同时访问。

2. HashEntry

HashEntry代表了哈希表的一个key-value项,它是ConcurrentHashMap的一个内部静态类,看下HashEntry的数据结构:

static final class HashEntry {
    final int hash;
    final K key;
    volatile V value;
    volatile HashEntry next;

    HashEntry(int hash, K key, V value, HashEntry next) {
      this.hash = hash;
      this.key = key;
      this.value = value;
      this.next = next;
    }
    //……
}

HashEntry数据结构也很简单,它是一个单链表结点,每个结点包括了key-value对、哈希值、指向下一个节点的引用。

3. Segment

ConcurrentHashMap最重要的概念就是Segment了,它是一个有锁功能的(继承了ReentrantLock)哈希表,ConcurrentHashMap正是由Segment数组组成的数据结构。

看下Segment的类声明:

static final class Segment extends ReentrantLock implements Serializable

Segment通过继承ReentrantLock拥有了锁的功能。

接着看下Segment的几个成员变量:


//获取锁失败后的尝试次数,和机器可用的cpu核数量有关
static final int MAX_SCAN_RETRIES =
  Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

//哈希表,一个segment对应一个哈希表
transient volatile HashEntry[] table;

//哈希表kv元素的个数,注意:ConcurrentHashMap的元素数量是所有segment的元素数量之和
transient int count;

//哈希表改变的次数
transient int modCount;

//哈希表重哈希的阀值,元素数量超过这个值,需要扩充哈希表,否则哈希冲突会增加
transient int threshold;

//加载因子
final float loadFactor;

这几个变量我们之前在学习HashMap的时候基本上都学习过,看下注释就可以了。

看下Segment唯一的一个构造方法:

Segment(float lf, int threshold, HashEntry[] tab) {
    this.loadFactor = lf;
    this.threshold = threshold;
    this.table = tab;
}

Segment没有默认构造方法。

接着看下Segment的put方法:

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    HashEntry node = tryLock() ? null :
    scanAndLockForPut(key, hash, value);
    //到这,一定成功获取锁了。
    //注意,此时node可能为空,也可能不为空。如果为空,接下来put的时候需要创建一个新的结点,如果不为空
    //可以直接使用该节点。

    //返回key对应老的value值
    V oldValue;
    try {
      HashEntry[] tab = table;
      //定位到HashEntry索引
      int index = (tab.length - 1) & hash;
      HashEntry first = entryAt(tab, index);
      for (HashEntry e = first;;) {
        if (e != null) {
          K k;
          //如果key已经存在,更新对应的value,跳出for循环
          if ((k = e.key) == key ||
              (e.hash == hash && key.equals(k))) {
            oldValue = e.value;
            if (!onlyIfAbsent) {
              e.value = value;
              ++modCount;
            }
            break;
          }
          e = e.next;
        }
        else {
          //node不为空,将node插入到链表的头部
          if (node != null)
            node.setNext(first);
          else
            //创建一个新的节点并插入到链表的头部
            node = new HashEntry(hash, key, value, first);
          int c = count + 1;
          //元素数量超过阀值,重哈希
          if (c > threshold && tab.length < MAXIMUM_CAPACITY)
            rehash(node);
          else
            //更新哈希表table处索引index处的值为node,每次插入都是插入到链表的头部
            setEntryAt(tab, index, node);
          ++modCount;
          count = c;
          oldValue = null;
          break;
        }
      }
    } finally {
      //释放锁
      unlock();
    }
    return oldValue;
}

该方法将指定的key-value添加到哈希表,如果key已经存在,更新对应的value值,否则创建一个新的节点,加入到哈希表。基本思路很简单,但是put之前加锁操作比较复杂。

put方法开始的时候,尝试获取锁,如果获取锁不成功,调用scanAndLockForPut方法,这个方法在尝试获取锁失败的情况下,如果key对应的节点存在,返回null,否则为后续put操作创建一个新的节点,该方法返回之前一定成功获取到锁。注意,虽然scanAndLockForPut方法在发现key对应的节点存在的情况下,返回了null,put方法还是会判断该key对应的节点是否存在,如果存在则更新value。

如果插入节点后的元素个数大于threshold,需要对该哈希表重哈希,重哈希后的哈希表容量是原来的2倍。

Segment的重哈希过程做了一个优化,找到该segment的HashEntry链表的某个元素lastIdx,使得从该元素开始到链表末尾的所有元素在新哈希表相同的桶中。这样,只需要将该元素之前的元素一个个的添加到新的链表即可,一定程度上复用了原来同一个槽上的部分节点,看下示意图:

ConcurrentHashMap源码详解_第2张图片

上图中红色节点在新哈希表的位置相同,直接复用这几个节点。红色节点之前的元素需要一个个的添加到新的哈希表中。原理介绍完了,看下代码的实现:

private void rehash(HashEntry node) {
    HashEntry[] oldTable = table;
    int oldCapacity = oldTable.length;
    //新容量扩充为原来的2倍
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    HashEntry[] newTable =
      (HashEntry[]) new HashEntry[newCapacity];
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
      HashEntry e = oldTable[i];
      if (e != null) {
        HashEntry next = e.next;
        int idx = e.hash & sizeMask;
        //只有一个节点的链表,直接放到新表即可
        if (next == null) 
          newTable[idx] = e;
        else { 
          //这段代码就是找到上图中lastIdx的节点
          HashEntry lastRun = e;
          int lastIdx = idx;
          for (HashEntry last = next;
               last != null;
               last = last.next) {
            int k = last.hash & sizeMask;
            if (k != lastIdx) {
              lastIdx = k;
              lastRun = last;
            }
          }
          //将lastIdx和后面的节点放到新的哈希表
          newTable[lastIdx] = lastRun;
          //lastIdx之前的节点一个个加到新的哈希表
          for (HashEntry p = e; p != lastRun; p = p.next) {
            V v = p.value;
            int h = p.hash;
            int k = h & sizeMask;
            HashEntry n = newTable[k];
            newTable[k] = new HashEntry(h, p.key, v, n);
          }
        }
      }
    }
    //新添加的节点node放到链表头部
    int nodeIndex = node.hash & sizeMask; 
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

接着看下Segment的remove方法:

final V remove(Object key, int hash, Object value) {
    //删除元素之前要获取锁
    if (!tryLock())
      scanAndLock(key, hash);
    V oldValue = null;
    try {
      HashEntry[] tab = table;
      int index = (tab.length - 1) & hash;
      HashEntry e = entryAt(tab, index);
      HashEntry pred = null;
      while (e != null) {
        K k;
        HashEntry next = e.next;
        if ((k = e.key) == key ||
            (e.hash == hash && key.equals(k))) {
          V v = e.value;
          if (value == null || value == v || value.equals(v)) {
            //value为空,只要key相同就删除,value不为空,要比较value和v是否相同
            if (pred == null)
              setEntryAt(tab, index, next);
            else
              pred.setNext(next);
            ++modCount;
            --count;
            oldValue = v;
          }
          break;
        }
        pred = e;
        e = next;
      }
    } finally {
      unlock();
    }
    return oldValue;
}

删除操作很简单,注意一点,当入参value为null,只要key相同就删除,否则需要比较value和当前节点的值是否相同。

删除之前需要获取锁,如果通过tryLock获取锁失败,调用scanAndLock获取锁。scanAndLock通过tryLock尝试获取次数越过MAX_SCAN_RETRIES,则调用lock方法阻塞等待锁,显然,lock方法将引起线程上下文切换,增加额外开销。

Segment还有两个replace的重载方法和一个clear方法,代码逻辑都很简单,不再说明了。

接着看下ConcurrentHashMap的put方法:

4. put

该方法将指定key-value对添加到哈希表中,其中key和value都不能为null,看下源码:

public V put(K key, V value) {
    Segment s;
    //value不能为null
    if (value == null)
      throw new NullPointerException();
    //如果key为null,hash方法会抛出NPE异常
    int hash = hash(key);
    int j = (hash >>> segmentShift) & segmentMask;
    if ((s = (Segment)UNSAFE.getObject         
         (segments, (j << SSHIFT) + SBASE)) == null) 
      s = ensureSegment(j);
    //委托给了segment的put方法
    return s.put(key, hash, value, false);
}

通过hash方法得到键key的哈希值,将该哈希值右移segmentShift位后和segmentMask执行“与”操作,得到key对应的segment索引。

看下segmentShift和segmentMask,构造函数中初始化了这两个值:

this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;

其中ssize是segments数组的大小,它是2的n次方,例如16、32、64等。sshift是数字n的大小,例如4、5、6等。获取key对应的segment段索引时,其实是通过键key哈希码的高sshift位来决定segment索引的。put方法最终委托给了segment的put方法,真正执行添加操作。

5. putAll

putAll将指定的Map添加到该ConcurrentHashMap,看下源码:

public void putAll(Map m) {
    for (Map.Entry e : m.entrySet())
      put(e.getKey(), e.getValue());
}

源码看简单,遍历指定的map,通过put方法一个个的添加。

6. get

该方法取key对应的value值,源码看着比较复杂:

public V get(Object key) {
    Segment s;
    HashEntry[] tab;
    //取键的哈希码
    int h = hash(key);
    //segment的内存偏移
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    if ((s = (Segment)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
      //找到key所在HashEntry链表
      for (HashEntry e = (HashEntry) UNSAFE.getObjectVolatile
           (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
           e != null; e = e.next) {
        K k;
        if ((k = e.key) == key || (e.hash == h && key.equals(k)))
          //找到了,返回对应的value
          return e.value;
      }
    }
    //没有找到,返回null
    return null;
}

7. containsKey

判断是否存在指定的key,看下源码:

public boolean containsKey(Object key) {
    Segment s; 
    HashEntry[] tab;
    //计算key对应的哈希码
    int h = hash(key);
    //计算key对应的segment内存偏移
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    if ((s = (Segment)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
      //找到key所在的桶,然后沿着链表查找
      for (HashEntry e = (HashEntry) UNSAFE.getObjectVolatile
           (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
           e != null; e = e.next) {
        K k;
        //找到了就返回true
        if ((k = e.key) == key || (e.hash == h && key.equals(k)))
          return true;
      }
    }
    return false;
}

containsKey方法逻辑上也很简单,首先通过key的哈希码,找到所在的segment,然后找到该key在该segment所在的桶,在这个桶的链表上查找该key,如果找到,返回true,否则返回false。

8. containsValue

该方法查找ConcurrentHashMap是否存在指定的value,如果存在key对应的value为该值,返回true,否则返回false。注意,该方法无法快速定位到segment和桶,只能整个遍历ConcurrentHashMap并比较value值,因此相对于containsKey,该方法就显得很慢了。

9. size

该方法返回ConcurrentHashMap的key-value对数量。将每个segment的key-value对数量相加,如果相加后发现modCount和上次保存的modCount不一样,说明相加过程中有线程修改了ConcurrentHashMap,为了获取准确的size,需要重试。如果重试次数超过指定的次数,锁住所有的segment,然后再执行相加操作,确保相加过程中没有线程能够修改。

看下源码:

public int size() {
    final Segment[] segments = this.segments;
    int size;
    boolean overflow;
    long sum;      
    long last = 0L;  
    //重试的次数
    int retries = -1;
    try {
      //外循环,直到本次的sum和上次的sum相同为止,本次sum和上次sum相同,
      //说明计算过程中,没有线程修改(改变modCount),计算的元素个数一定是准确的
      for (;;) {
        //重试超过指定次数(默认为2),将所有的segment都锁住,防止计算size
        //过程被线程修改,元素个数计算完成后再解锁
        if (retries++ == RETRIES_BEFORE_LOCK) {
          for (int j = 0; j < segments.length; ++j)
            ensureSegment(j).lock();
        }
        //保存所有segment的修改次数
        sum = 0L;
        //元素个数
        size = 0;
        overflow = false;
        //遍历所有的segment,累加每个segment的元素个数
        for (int j = 0; j < segments.length; ++j) {
          Segment seg = segmentAt(segments, j);
          if (seg != null) {
            sum += seg.modCount;
            int c = seg.count;
            if (c < 0 || (size += c) < 0)
              overflow = true;
          }
        }
        //相等,说明计算过程中没有线程进行修改操作,计算结果是正确的,跳出外循环,否则继续重试计算
        if (sum == last)
          break;
        last = sum;
      }
    } finally {
      //重试超过指定次数,将所有的segment解锁
      if (retries > RETRIES_BEFORE_LOCK) {
        for (int j = 0; j < segments.length; ++j)
          segmentAt(segments, j).unlock();
      }
    }
    //size溢出,返回Inter.MAX_VALUE
    return overflow ? Integer.MAX_VALUE : size;
}

10. isEmpty

该方法判断ConcurrentHashMap的元素个数是否为0。遍历每个segment,若当前segment的元素个数不为0,返回false。否则将每个segment的modCount累加,结果计为sum。若sum等于0,说明遍历过程中元素个数未变,返回true。否则继续第二次遍历segment,同样的道理,遍历过程中若发现当前segment的元素个数不为0,返回false,否则将sum减去当前segment的modCount,若遍历结束后sum不等于0,说明第二次遍历过程,元素个数有改变,返回false。

看下源码:

public boolean isEmpty() {
    long sum = 0L;
    final Segment[] segments = this.segments;
    //第一次遍历segments
    for (int j = 0; j < segments.length; ++j) {
      Segment seg = segmentAt(segments, j);
      if (seg != null) {
        //若当前segment的元素个数不等于0,说明不为空,返回false
        if (seg.count != 0)
          return false;
        //累加modCount
        sum += seg.modCount;
      }
    }
    if (sum != 0L) {
      //第二次遍历
      for (int j = 0; j < segments.length; ++j) {
        Segment seg = segmentAt(segments, j);
        if (seg != null) {
          //同样的道理,若当前segment的元素个数不等于0,说明不为空,返回false
          if (seg.count != 0)
            return false;
          //减去当前的modCount
          sum -= seg.modCount;
        }
      }
      //说明第二次遍历过程中元素个数有改变,认为不为空
      if (sum != 0L)
        return false;
    }
    return true;
}

为何要进行两次遍历操作?正常情况下应该将所有的segment锁住,遍历所有的segment,累加元素个数,判断是否为0,然后再解锁。但是为了减少加锁对性能的影响,采用两次操作来判断是否为空。这种方法可以避免加锁的性能影响,但是也会失去100%的正确性,某些情况下,该方法返回true并不真正意味着该Map为空。例如第一次遍历过程中,当遍历到第二个segment,有其他线程已经往第一个segment添加了元素,但是我们遍历第一个segment的时候,该segment的modCount为0,第一次遍历结束后我们可能得到sum等于0,返回了true。但是第一个segment的的元素个数已经不为0了。这也是ConcurrentHashMap的弱一致性表现,为了性能,这种折衷和妥协也是可以理解的。

11. 迭代器

ConcurrentHashMap的迭代器实现了Iterator接口,并继承了内部抽象类HashIterator。迭代器的功能都委托给了HashIterator。看下HashIterator源码:

abstract class HashIterator {
    int nextSegmentIndex;
    int nextTableIndex;
    HashEntry[] currentTable;
    HashEntry nextEntry;
    HashEntry lastReturned;

    HashIterator() {
      nextSegmentIndex = segments.length - 1;
      nextTableIndex = -1;
      advance();
    }

    //将nextEntry指向非空的HashEntry节点
    final void advance() {
      for (;;) {
        if (nextTableIndex >= 0) {
          if ((nextEntry = entryAt(currentTable,
                                   nextTableIndex--)) != null)
            break;
        }
        else if (nextSegmentIndex >= 0) {
          Segment seg = segmentAt(segments, nextSegmentIndex--);
          if (seg != null && (currentTable = seg.table) != null)
            nextTableIndex = currentTable.length - 1;
        }
        else
          break;
      }
    }
    //返回下一个节点,如果返回节点的下一个节点为空,需要调用advance找到下一个非空的节点
    final HashEntry nextEntry() {
      HashEntry e = nextEntry;
      if (e == null)
        throw new NoSuchElementException();
      lastReturned = e;
      //下一个节点为空,找到下一个非空的节点
      if ((nextEntry = e.next) == null)
        advance();
      return e;
    }

    public final boolean hasNext() { return nextEntry != null; }
    public final boolean hasMoreElements() { return nextEntry != null; }

    //删除操作委托给了外部类的remove方法,注意删除节点后需要将lastReturned设置为null
    public final void remove() {
      if (lastReturned == null)
        throw new IllegalStateException();
      ConcurrentHashMap.this.remove(lastReturned.key);
      lastReturned = null;
    }
}

这里需要注意的是,构造函数和nextEntry方法中,需要确保下一个节点非空,如果为空,说明迭代器遍历结束了。

ConcurrentHashMap的迭代器,如KeyIterator、ValueIterator、EntryIterator等都继承了HashEntry,并委托给了HashIterator,方法都很简单,不再说明。

参考源码:jdk1.7.0_79

你可能感兴趣的:(jdk集合)