ConcurrentHashMap原理分析

相关的HashMap

java里面有HashMap、HashTable 和 ConcurrentHashMap。其中HashTable 和 ConcurrentHashMap是线程安全的,但是原理不一样。HashTable使用synchronized来锁住整张Hash表来实现线程安全,即每次锁住整张表让线程独占,这样会造成并发量很低。

初识ConcurrentHashMap

ConcurrentHashMap允许多个修改操作并发进行,其关键在于使用了锁分离技术。它使用了多个锁来控制对hash表的不同部分进行的修改。ConcurrentHashMap内部使用段(Segment)来表示这些不同的部分,每个段其实就是一个小的Hashtable,它们有自己的锁。只要多个修改操作发生在不同的段上,它们就可以并发进行。
有些方法需要跨段,比如size()和containsValue(),它们可能需要锁定整个表而而不仅仅是某个段,这需要按顺序锁定所有段,操作完毕后,又按顺序释放所有段的锁。这里“按顺序”是很重要的,否则极有可能出现死锁。

ConcurrentHashMap原理

本文章所有分析基于jdk1.7

原理图

ConcurrentHashMap原理分析_第1张图片
类图
ConcurrentHashMap原理分析_第2张图片
内部结构图

原理分析

ConcurrentHashMap使用分段锁技术,将数据分成一段一段的存储,然后给每一段数据配一把锁,当一个线程占用锁访问其中一个段数据的时候,其他段的数据也能被其他线程访问,能够实现真正的并发访问

数据结构

  • Segment
    从内部结构图中可以看到,ConcurrentHashMap内部分为很多个Segment,每一个Segment拥有一把锁,然后每个Segment(继承ReentrantLock)
    Segment的声明如下
 static final class Segment extends ReentrantLock implements Serializable 

Segment继承了ReentrantLock,表明每个segment都可以当做一个锁。segment里面是HashEntry的数组(ReentrantLock的使用可以参考文章)这样对每个segment中的数据需要同步操作的话都是使用每个segment容器对象自身的锁来实现。只有对全局需要改变时锁定的是所有的segment。

  • HashEntry
 static final class HashEntry {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry next;
}

常用操作

1 初始化

先看一下初始化代码

/**
     * Creates a new, empty map with the specified initial
     * capacity, load factor and concurrency level.
     *
     * @param initialCapacity the initial capacity. The implementation
     * performs internal sizing to accommodate this many elements.
     * @param loadFactor  the load factor threshold, used to control resizing.
     * Resizing may be performed when the average number of elements per
     * bin exceeds this threshold.
     * @param concurrencyLevel the estimated number of concurrently
     * updating threads. The implementation performs internal sizing
     * to try to accommodate this many threads.
     * @throws IllegalArgumentException if the initial capacity is
     * negative or the load factor or concurrencyLevel are
     * nonpositive.
     */
    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // create segments and segments[0]
        Segment s0 =
            new Segment(loadFactor, (int)(cap * loadFactor),
                             (HashEntry[])new HashEntry[cap]);
        Segment[] ss = (Segment[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }

传入的参数有initialCapacity,loadFactor,concurrencyLevel这三个。
initialCapacity表示新创建的这个ConcurrentHashMap的初始容量,也就是上面的结构图中的Entry数量。默认值为static final int DEFAULT_INITIAL_CAPACITY = 16;
loadFactor表示负载因子,就是当ConcurrentHashMap中的元素个数大于loadFactor * 最大容量时就需要rehash,扩容。默认值为static final float DEFAULT_LOAD_FACTOR = 0.75f;
concurrencyLevel 表示并发级别,这个值用来确定Segment的个数,Segment的个数是大于等于concurrencyLevel的第一个2的n次方的数。比如,如果concurrencyLevel为12,13,14,15,16这些数,则Segment的数目为16(2的4次方)。默认值为static final int DEFAULT_CONCURRENCY_LEVEL = 16;。理想情况下ConcurrentHashMap的真正的并发访问量能够达到concurrencyLevel,因为有concurrencyLevel个Segment,假如有concurrencyLevel个线程需要访问Map,并且需要访问的数据都恰好分别落在不同的Segment中,则这些线程能够无竞争地自由访问(因为他们不需要竞争同一把锁),达到同时访问的效果。这也是为什么这个参数起名为“并发级别”的原因。

初始化的一些动作:

  • 验证参数的合法性,如果不合法,直接抛出异常。
  • concurrencyLevel也就是Segment的个数不能超过规定的最大Segment的个数,默认值为static final int MAX_SEGMENTS = 1 << 16;,如果超过这个值,设置为这个值。
  • 然后使用循环找到大于等于concurrencyLevel的第一个2的n次方的数ssize,这个数就是Segment数组的大小,并记录一共向左按位移动的次数sshift,并令segmentShift = 32 - sshift,并且segmentMask的值等于ssize - 1,segmentMask的各个二进制位都为1,目的是之后可以通过key的hash值与这个值做&运算确定Segment的索引。
  • 检查给的容量值是否大于允许的最大容量值,如果大于该值,设置为该值。最大容量值为static final int MAXIMUM_CAPACITY = 1 << 30;。
    然后计算每个Segment平均应该放置多少个元素,这个值c是向上取整的值。比如初始容量为15,Segment个数为4,则每个Segment平均需要放置4个元素。
  • 最后创建一个Segment实例,将其当做Segment数组的第一个元素。

2、put操作

/**
     * Maps the specified key to the specified value in this table.
     * Neither the key nor the value can be null.
     *
     * 

The value can be retrieved by calling the get method * with a key that is equal to the original key. * * @param key key with which the specified value is to be associated * @param value value to be associated with the specified key * @return the previous value associated with key, or * null if there was no mapping for key * @throws NullPointerException if the specified key or value is null */ @SuppressWarnings("unchecked") public V put(K key, V value) { Segment s; if (value == null) throw new NullPointerException(); int hash = hash(key); int j = (hash >>> segmentShift) & segmentMask; if ((s = (Segment)UNSAFE.getObject // nonvolatile; recheck (segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment s = ensureSegment(j); return s.put(key, hash, value, false); }

操作步骤如下:

  • 判断value是否为null,如果为null,直接抛出异常。
  • key通过一次hash运算得到一个hash值。(这个hash运算下文详说)
  • 将得到hash值向右按位移动segmentShift位,然后再与segmentMask做&运算得到segment的索引j。
    在初始化的时候我们说过segmentShift的值等于32-sshift,例如concurrencyLevel等于16,则sshift等于4,则segmentShift为28,segmentMask为15,对应的二进制为 0000 0000 0000 0000 0000 0000 0000 1111。hash值是一个32位的整数,将其向右移动28位就变成这个样子:
    0000 0000 0000 0000 0000 0000 0000 xxxx,然后再用这个值与segmentMask做&运算,也就是取最后四位的值。这个值确定Segment的索引。
  • 使用Unsafe的方式从Segment数组中获取该索引对应的Segment对象。
  • 向这个Segment对象中put值,这个put操作也基本是一样的步骤(通过&运算获取HashEntry的索引,然后set)。
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            HashEntry node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                HashEntry[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry first = entryAt(tab, index);
                for (HashEntry e = first;;) {
                    if (e != null) {
                        K k;
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    else {
                        if (node != null)
                            node.setNext(first);
                        else
                            node = new HashEntry(hash, key, value, first);
                        int c = count + 1;
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            rehash(node);
                        else
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                unlock();
            }
            return oldValue;
        }

3、get操作

/**
     * Returns the value to which the specified key is mapped,
     * or {@code null} if this map contains no mapping for the key.
     *
     * 

More formally, if this map contains a mapping from a key * {@code k} to a value {@code v} such that {@code key.equals(k)}, * then this method returns {@code v}; otherwise it returns * {@code null}. (There can be at most one such mapping.) * * @throws NullPointerException if the specified key is null */ public V get(Object key) { Segment s; // manually integrate access methods to reduce overhead HashEntry[] tab; int h = hash(key); long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; if ((s = (Segment)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) { for (HashEntry e = (HashEntry) UNSAFE.getObjectVolatile (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); e != null; e = e.next) { K k; if ((k = e.key) == key || (e.hash == h && key.equals(k))) return e.value; } } return null; }

操作步骤为:

  • 和put操作一样,先通过key进行两次hash确定应该去哪个Segment中取数据。
  • 使用Unsafe获取对应的Segment,然后再进行一次&运算得到HashEntry链表的位置,然后从链表头开始遍历整个链表(因为Hash可能会有碰撞,所以用一个链表保存),如果找到对应的key,则返回对应的value值,如果链表遍历完都没有找到对应的key,则说明Map中不包含该key,返回null。
  • 值得注意的是,get操作是不需要加锁的(如果value为null,会调用readValueUnderLock,只有这个步骤会加锁),通过前面提到的volatile和final来确保数据安全。

4、 size操作

/**
     * Returns the number of key-value mappings in this map.  If the
     * map contains more than Integer.MAX_VALUE elements, returns
     * Integer.MAX_VALUE.
     *
     * @return the number of key-value mappings in this map
     */
    public int size() {
        // Try a few times to get accurate count. On failure due to
        // continuous async changes in table, resort to locking.
        final Segment[] segments = this.segments;
        int size;
        boolean overflow; // true if size overflows 32 bits
        long sum;         // sum of modCounts
        long last = 0L;   // previous sum
        int retries = -1; // first iteration isn't retry
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); // force creation
                }
                sum = 0L;
                size = 0;
                overflow = false;
                for (int j = 0; j < segments.length; ++j) {
                    Segment seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                if (sum == last)
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

size操作与put和get操作最大的区别在于,size操作需要遍历所有的Segment才能算出整个Map的大小,而put和get都只关心一个Segment。假设我们当前遍历的Segment为SA,那么在遍历SA过程中其他的Segment比如SB可能会被修改,于是这一次运算出来的size值可能并不是Map当前的真正大小。所以一个比较简单的办法就是计算Map大小的时候所有的Segment都Lock住,不能更新(包含put,remove等等)数据,计算完之后再Unlock。这是普通人能够想到的方案,但是牛逼的作者还有一个更好的Idea:先给3次机会,不lock所有的Segment,遍历所有Segment,累加各个Segment的大小得到整个Map的大小,如果某相邻的两次计算获取的所有Segment的更新的次数(每个Segment都有一个modCount变量,这个变量在Segment中的Entry被修改时会加一,通过这个值可以得到每个Segment的更新操作的次数)是一样的,说明计算过程中没有更新操作,则直接返回这个值。如果这三次不加锁的计算过程中Map的更新次数有变化,则之后的计算先对所有的Segment加锁,再遍历所有Segment计算Map大小,最后再解锁所有Segment

举个例子

一个Map有4个Segment,标记为S1,S2,S3,S4,现在我们要获取Map的size。计算过程是这样的:第一次计算,不对S1,S2,S3,S4加锁,遍历所有的Segment,假设每个Segment的大小分别为1,2,3,4,更新操作次数分别为:2,2,3,1,则这次计算可以得到Map的总大小为1+2+3+4=10,总共更新操作次数为2+2+3+1=8;第二次计算,不对S1,S2,S3,S4加锁,遍历所有Segment,假设这次每个Segment的大小变成了2,2,3,4,更新次数分别为3,2,3,1,因为两次计算得到的Map更新次数不一致(第一次是8,第二次是9)则可以断定这段时间Map数据被更新,则此时应该再试一次;第三次计算,不对S1,S2,S3,S4加锁,遍历所有Segment,假设每个Segment的更新操作次数还是为3,2,3,1,则因为第二次计算和第三次计算得到的Map的更新操作的次数是一致的,就能说明第二次计算和第三次计算这段时间内Map数据没有被更新,此时可以直接返回第三次计算得到的Map的大小。最坏的情况:第三次计算得到的数据更新次数和第二次也不一样,则只能先对所有Segment加锁再计算最后解锁。

注意事项

  • ConcurrentHashMap中的key和value值都不能为null,HashMap中key可以为null,HashTable中key不能为null。
  • ConcurrentHashMap是线程安全的类并不能保证使用了ConcurrentHashMap的操作都是线程安全的!
  • ConcurrentHashMap的get操作不需要加锁,put操作需要加锁

你可能感兴趣的:(ConcurrentHashMap原理分析)