【JavaEE】ThreadLocal源码解析

在之前的博文【JavaEE】关于ThreadLocal和模态框的关闭中,我们曾经用到过ThreadLocal,当时对于ThreadLocal的理解是我们可以将两个彼此毫无关系的线程之间建立关系。但是这到底是怎么实现的?现在让我们来对它的源码进行一下探究。
首先,可以看到,ThreadLocal类是一个泛型类

public class ThreadLocal<T> {
	……
}

所以,我们在使用它的时候,必须要给<>里面增加一个类型。
我们之前使用它的时候,是在某一个线程中通过无参构造,然后覆盖initialValue()方法,再在另一个线程里面通过get()方法,得到initialValue()方法返回的实例对象。那么我们进而来看一下这两个方法的源码。
initialValue()方法:

	protected T initialValue() {
        return null;
    }

这个方法是被protected所修饰的,那么它自然是可以被继承的,而它的返回值类型是一个泛型,也恰好是ThreadLocal这个泛型类的类型,那么我就可以在外面定义任意的类型,然后将它作为泛型,然后通过initialValue()方法,将这个类型的实例对象进行返回。
再看get()方法:

	public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

可以看到,get()方法的返回值类型同样是一个泛型,然而,它return的却是一个setInitialValue(),并且,没有任何的参数传进去,那么我们就可以先"忽略"(后面我们会对这些代码进行详细分析)中间的代码,直接看setInitialValue()方法:

    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

不用说这个方法的返回值必定是泛型类型。它返回的value是和initialValue()方法有关,并且又没有传参,value的值在中间的代码里面也没有任何的修改,initialValue()方法在之前我们已经看到过它的源代码,它就是返回一个泛型类型。所以,到这里,我们可以明白,为什么我们之前对于ThreadLocal的使用,通过initialValue()方法和get()方法可以得到一个泛型类的对象,原因正是因为我们在new的时候通过覆盖initialValue()方法返回了它的对象,然后在get()方法调用了initialValue()的返回值,也等于是直接进行了返回,所以我们可以得到这个对象。

事实上,我们在get()方法中,可以看到有两个if判断,如果全满足,会返回一个结果。我们之前的做法,其实是get()方法中的一种情况,这种情况其实对应了我们没有对ThreadLocal这个类进行过任何的 set 等操作,所以,队员两个if都是不满足的,直接返回了我们最初initialValue()方法的结果。那接下来,我们从set方法开始进行分析,之后再重新分析一下get()方法的另外一种情况。
set方法(单参):

    public void set(T value) {
    	//	通过native方法,获得当前的线程
        Thread t = Thread.currentThread();
        //	再得到该Thread对象的ThreadLocalMap成员 map
        //	如果从来没有进行过set,那结果肯定是null
        ThreadLocalMap map = getMap(t);
        //	判断map是否为null   
        if (map != null)
        	//	若不为空,需要调用map的set()方法
            map.set(this, value);
        else
        	//	若map为空,需要先执行createMap()方法,初始化ThreadLocalMap 
            createMap(t, value);
    }

上面提到了很多之前没没提到的内容,现在来一一讲解:
Thread类,

public
class Thread implements Runnable {
	//	其中,众多成员中,这个成员和我们的ThreadLocal是很相关的。
	//	它的类型是ThreadLocal.ThreadLocalMap类型,可见ThreadLocalMap是ThreadLocal的内部类
	ThreadLocal.ThreadLocalMap threadLocals = null;
	……
}

我们再来看一下ThreadLocalMap这个内部类:

static class ThreadLocalMap {
	//	Entry继承WeakReference,并且用ThreadLocal作为key
	//	如果它的get()返回值为空,那么表明对象已经被回收了,也就是可以重复对这个空间赋值
	//	关于WeakReference类的详细解析,请看我转载的这篇博文(https://www.cnblogs.com/zjj1996/p/9140385.html)
	static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
    }
    //	数组的初始容量
  	private static final int INITIAL_CAPACITY = 16;
	//	一个节点数组
  	private Entry[] table;
  	//	存放的节点个数
  	private int size = 0;
  	//	阈值
  	private int threshold;

现在可以明确的是,对于Thread类,每一个Thread类都会有一个ThreadLocalMap ,来存放多个线程本地变量。
以上是这个类的成员,可以看到,这些成员与HashMap里面的成员有很多是相同的(但是这里没有写出来负载因子),那么它的存储结构也应该有很多相似地方,我们继续往下看。
getMap()方法:

  ThreadLocalMap getMap(Thread t) {
     	//	返回的是t线程的ThreadLocalMap 
        return t.threadLocals;
   }

createMap()方法:

    void createMap(Thread t, T firstValue) {
    	//	调用了ThreadLocalMap的构造方法
    	//	把自己和一个泛型对象传进去
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    
  	ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
  			//	因为是第一次创建ThreadLocalMap对象,所以要先初始化table数组
            table = new Entry[INITIAL_CAPACITY];
            //	这一步操作和HashMap根据哈希值确定数组下标有些像,
            //	都是和(长度-1)进行相与,其结果作为数组的下标值
            //	而这里的threadLocalHashCode,其实是调用了AtomicInteger类的getAndAdd()方法
			//	AtomicInteger提供原子操作来进行Integer的使用,因此十分适合高并发情况下的使用
			//	https://www.cnblogs.com/zhaoyan001/p/8885360.html
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            //	将键和值放到一个Entry节点中,并将节点放到数组中
            table[i] = new Entry(firstKey, firstValue);
            //	因为只放了一个,所以当前长度为1;
            size = 1;
            //	设置阈值
            setThreshold(INITIAL_CAPACITY);
    }

	private void setThreshold(int len) {
			//	在HashMap中,阈值等于加载因子*capacity(数组容量)
			//	那么,这里的加载因子就是 2/3(HashMap中是0.75)
            threshold = len * 2 / 3;
    }

这时,对于第一个泛型类型的value就已经set好了;那如果设置第二个value的时候,因为ThreadLocalMap 已经初始化过了,将会执行map.set(this, value);这个方法。现在来分析一下这个方法:

		//	这个方法是在ThreadLocalMap 里面的
		private void set(ThreadLocal<?> key, Object value) {
			//	下面这个英文注释已经表明,set方法并没有像get()方法那样快速
			//	后面我们会对get方法进行分析
            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.
		
			//	给tab赋值
            Entry[] tab = table;
            //	获取长度
            int len = tab.length;
            //	根据哈希值计算下标的位置
            int i = key.threadLocalHashCode & (len-1);

			//	开始进行遍历
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
				//	如果键和值都相等,说明存在了该对象,直接进行更新
                if (k == key) {
                    e.value = value;
                    return;
                }

 				//	如果键为null,说明被回收了
				//	这个时候说明改table[i]可以重新使用,
				//	用新的key-value将其替换,并删除其他无效的entry
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
			//	只要循环结束,那就表明在当前下标的值为null,直接重新生成一个节点并插入
            tab[i] = new Entry(key, value);
            //	当然节点数要增加
            int sz = ++size;
            //	开始进行清除,因为有的数组空间,它自己不为null,但是他的get()方法为null,
            //	也就是上面的if(k == null),将这些空间进行清理
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
            	//	如果没有清理,检测当前存储的值是否超过阈值,进行扩容
                rehash();
        }

上面代码有写道nextIndex()方法,replaceStaleEntry(),cleanSomeSlots(),rehash()方法后面我们一一解析。先看nextIndex()方法,它和prevIndex()方法是对应的。

private static int nextIndex(int i, int len) {
 			//	判断当前下标加一(也就是后一个下标)是否超过数组长度,
 			//	没有超过就返回,超过就从下标为0开始
            return ((i + 1 < len) ? i + 1 : 0);
}
private static int prevIndex(int i, int len) {
			//	判断当前下标减一(也就是前一个下标)是否小于0,
			//	没有,就返回,小于了就从数组最末尾开始
            return ((i - 1 >= 0) ? i - 1 : len - 1);
}

其实,这两个函数,目的就是为了解决hash冲突,然后对数组进行了一个遍历,包括从前遍历和从后遍历两种。这也就是所谓的线性探测法。可以看到,这种方法的缺点就是假如数组特别长,其中为null的数组特别少,那么遍历起来是费时间的。画个图表示:
【JavaEE】ThreadLocal源码解析_第1张图片再看replaceStaleEntry()方法:

	//	替换不新鲜的Entry
	//	这个函数,源码已经给了英文解释
	private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            //	要抹去的位置先赋值为当前值
            int slotToExpunge = staleSlot;
            //	从当前开始,从前遍历,直到找到第一个null数组,
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            //	从当前开始,从后遍历,
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                //	如果找到了key,进行赋值,
                //	并且将当前值与最开始的tab[staleSlot]值进行替换,         
                //	在这里可能有人会有个疑问,
                //	为什么key相等,不在原来hash值对应的下标存放,反而要到这个位置来存放?
                //	原因很简单,在set方法中,我们是从hash值计算得到对应的下标开始遍历的
                //	假如,那个下标有值了,那么不得不更换位置进行存储,所以找到了一块已“失效”的空间
                //	这也是为了避免浪费,将“失效”的空间利用起来
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    //	如果向前查找没有找到(e.get() == null)的节点,         
                    if (slotToExpunge == staleSlot)
                    	//	令要抹去的为当前i
                        slotToExpunge = i;
                    //	进行清除工作
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                //	如果当前的节点(e.get() == null),
                //	并且向前查找没有找到(e.get() == null)的节点
                if (k == null && slotToExpunge == staleSlot)
                	//	那么令要抹去的位置为i;
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            //	执行到这,表明,没有找到,
            //	key之前不存在table中
            tab[staleSlot].value = null;
            //	那就在当前直接进行替换
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            //	最开始slotToExpunge与staleSlot 的结果是相等的,进行了两次遍历
            //	若slotToExpunge != staleSlot,说明存在其他的无效entry需要进行清理。
            if (slotToExpunge != staleSlot)
            	//	开始进行清理
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

然后我们进行cleanSomeSlots()函数的解析:

		//	清理工作
		private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
            	//	进行遍历
                i = nextIndex(i, len);
                Entry e = tab[i];
                //	表明当前节点是“无效的”,要进行清楚
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //	进行清除
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);//	从这里可见,n表示了循环的次数	
            //	返回是否进行了清理
            return removed;
		}
		
		private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            //	对传入过来的数组元素进行清楚
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            //	存入的节点的长度也要跟着减
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            //	继续从当前清理过的下一个元素开始,进行遍历	
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    //	表明是“无效”节点,需要进行清除
                    e.value = null;
                    tab[i] = null;
                    //	存入的节点的长度也要跟着减
                    size--;
                } else {
                	//	检测一下当前的key的hash值与数组长度-1的结果h是否和当前元素的下标i相等
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                    	//	不相等,立马将i设置为null	
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        //	开始寻找这个h,如果数组为h的不等于null(表明已经存放了别的值)
                        //	那就继续遍历,找到为null的下标
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        //	将下标存入	
                        tab[h] = e;
                    }
                }
            }
            //	返回下一个为null的solt的下标。
            return i;
       }

rehash()方法源码:

		 private void rehash() {
		 	//	它先调用了expungeStaleEntries()方法
		 	//	先对数组作清理工作
		     expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            //	因为阈值 threshold = len * 2 /3
            //	经过换算 threshold * 3 / 4 = len / 2
            //	在HashMap中,当size = threshold 的时候就要进行扩容,
            //	而这里应该是把这个阈值变得更小了,
            if (size >= threshold - threshold / 4)
            	//	开始进行扩容
                resize();
        }
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                //	对“失效”的元素进行清理
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }        
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            //	新数组的长度也是扩大了两倍
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
			//	对原数组进行遍历
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    //	如果当前的元素“无效”
                    if (k == null) {
                    	//	直接将原数组的值设置为null
                    	//	可以看到,在源代码的注释中也写出了注释: 帮助垃圾回收机制
                        e.value = null; // Help the GC
                    } else {
                    	//	执行到这里,说明数组中存的是有效元素
                    	//	重新计算下标的位置,
                        int h = k.threadLocalHashCode & (newLen - 1);
                        //	通过下标h,进行检测,是否存了数据,
                        while (newTab[h] != null)
                        	//	如果存了,那就继续寻找为null的
                            h = nextIndex(h, newLen);
                        //	找到后,进行数组的赋值
                        newTab[h] = e;
                        //	新数组中的有效节点个数加一
                        count++;
                        //	继续循环遍历
                    }
                }
            }
			//	根据当前新数组的长度,更改阈值
            setThreshold(newLen);
            //	设置新数组的有效节点数
            size = count;
            //	将新数组赋值给成员table
            table = newTab;
        }
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

这时,我们可以再来看get方法:

	public T get() {
        Thread t = Thread.currentThread();
        //	获取在对应线程中的ThreadLocalMap实例
        ThreadLocalMap map = getMap(t);
        //	检测是否为null
        if (map != null) {
        	//	获取该ThreadLocal所对应的Entry实例
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    private Entry getEntry(ThreadLocal<?> key) {
    	//	通过传过来的key,计算出对应的数组下标
        int i = key.threadLocalHashCode & (table.length - 1);
        //	取得该节点
        Entry e = table[i];
        //	再进行一次检测,确保取得的结果正确
        if (e != null && e.get() == key)        	
            return e;
        else
        	//	不正确,继续在数组的别的下标查找
            return getEntryAfterMiss(key, i, e);
    }
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;
		//	进行遍历查找
        while (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == key)
                return e;
            if (k == null)
            	//	找到一个“失效”节点,直接进行清除
                expungeStaleEntry(i);
            else
            	//	下标向下一位移动
                i = nextIndex(i, len);
            //	继续进行遍历
            e = tab[i];
        }
        //	表明没有找到,返回null
        return null;
    }

最后看remove()方法:

		//	移除某个节点
		private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            //	计算出节点下标
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                	//	如果key一样,不管是不是“有效”,全部进行清除
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

最后看remove就相对简单很多了。
综合来说,上面的很多方法其实都是ThreadLocalMap的内部的方法,因为一个Thread对应一个ThreadLocalMap,而ThreadLocalMap是ThreadLoacl的一个内部类。ThreadLocal的好多方法的实现正是调用ThreadLocalMap,因为他们彼此之间都是唯一的。
存储结构,它采用的是线性探测法解决hash冲突。对于它的存取过程而言,和HashMap一样,会先进行,数组的定位,然后再具体定位,定位之前也是先进行判断,如果遇到“无效”的,直接进行删除。而且,它在很多时候,不论是存还是取,都不忘记去及时清理“无效”节点,这是我编程的最大的收获。

你可能感兴趣的:(ThreadLocal源码,线性探测法)