ThreadLocal解析

前言

我们都知道ThreadLocal用于为每个线程存储自己的变量值,起到线程间隔离的作用,那么它到底是怎么运行的呢,让我们通过一段demo来进行一下源码分析。

    public static void main(String[] args) {

        ThreadLocal sThreadLocal = new ThreadLocal();
        new Thread(()->{sThreadLocal.set(1);System.out.println("线程1的threadlocal值:"+sThreadLocal.get());}).start();
        new Thread(()->{sThreadLocal.set(2);System.out.println("线程2的threadlocal值:"+sThreadLocal.get());}).start();

    }

输出结果:

线程1的threadlocal值:1
线程2的threadlocal值:2

源码解析

set方法

首先来看一下set方法做了什么

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

这里调用了getMap(t)方法,来看一下

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

可以看到返回了当前线程的threadLocals属性,当该属性不为空时调用其对应的set方法,否则调用createMap方法进行初始化,首先来看一下createMap方法

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

这里主要做的事情是初始化当前线程的threadLocals,来看一下构造方法

        ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

这里首先创建了一个Entry类型的数组,数组大小为INITIAL_CAPACITY的值16,EntryThreadLocal的一个内部类,定义为

        static class Entry extends WeakReference> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
        }

该类继承了WeakReference,因此很明显是一种弱引用的方式,这里其实存在一个潜在的内存泄漏问题,那就是key因为弱引用的关系回收了,但该Entry对象由于仍可能被ThreadLocalMap对象强引用而无法释放,这样该Entry就变成了一个“脏对象”,为此代码里在其他地方对这个问题进行了优化,后面会讲到。

i是数组中的下标,通过当前线程的threadLocalHashCode计算得来,而threadLocalHashCode的计算过程如下:

private final int threadLocalHashCode = nextHashCode();
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

这里的nextHashCode定义如下

    private static AtomicInteger nextHashCode =
        new AtomicInteger();

所以threadLocalHashCode实质是一个以指定步长进行累加的累加器,该步长能较好的将连续的线程ID散列到2的幂次方的数组中。另外需要说明的是,传入的Entry的key值是当前ThreadLocal对象,也就是说这个ThreadLocal对象是被弱引用的对象,如果没有别的地方对其进行了强引用,一旦触发gc该对象就会被回收。

看完createMap方法初始化map后,来看set方法

        private void set(ThreadLocal key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {//遍历Entry不为空的节点
                ThreadLocal k = e.get();

                if (k == key) { //若该Entry的key为当前的ThreadLocal对象
                    e.value = value;
                    return;
                }

                if (k == null) { //若该ThreadLocal对象已被回收
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);//遍历到Entry空的节点则创建
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

由上述代码看到,这里主要做的是在一个for循环中遍历寻找Entry不为空的节点,一旦获取到就填入新的Entry值,更新数组size并根据阈值判断是否执行rehash()方法更新数组。

而当遍历到的Entry为非空节点时,会有以下操作:若该Entry的key为当前的ThreadLocal对象时,直接赋值value;若当获取到的Entry为脏对象时,会调用replaceStaleEntry(key, value, i)方法进行清理。

清理方法

这里有几个方法值得我们具体看一下,首先是cleanSomeSlots(i, sz)

        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);//找到脏entry并清除掉
                }
            } while ( (n >>>= 1) != 0);//通过n控制循环次数
            return removed;
        }

该方法用来遍历清除脏Entry,一旦遍历过程中发现了脏Entry,则会调用expungeStaleEntry(i)方法清除掉,并且重置n增加遍历次数。那么expungeStaleEntry(i)做了什么呢

        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

可以看到清除脏Entry的方式其实很简单,就是将该Entry位置设为null,这样一来失去了强引用的脏Entry就会被gc回收。另外可以看到的是,expungeStaleEntry(i)方法清除了i位置的脏Entry后,并不会停下,而是会继续遍历下一个位置清除脏Entry

接着看一下replaceStaleEntry(key, value, i)方法

        private void replaceStaleEntry(ThreadLocal key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;//向前找到第一个脏Entry

            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();

                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    //如果在查找过程中还未发现脏Entry,那么就以当前位置作为清除的起点
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }
                   //如果向前未搜索到脏Entry,而在查找过程遇到脏Entry的话,后面就以此时这个位置作为起点执行清除
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // 没有发现对应的key,则在该脏位置创建新Entry
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            //清除剩余脏Entry
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

该方法首先前向搜寻脏Entry记录为slotToExpunge,接着从staleSlot位置开始后向搜索,如果在查找过程中未发现脏Entry,且存在当前的key,那么赋值value,并且以当前位置staleSlot作为清除的起点;若for循环结束仍未找到对应的key,则在staleSlot位置创建新的Entry节点,并从slotToExpunge位置开始清除剩余的脏Entry。

get方法

看完了ThreadLocal的set方法,接着来看看其get方法

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

这里可以看到,首先通过getMap方法获取当前线程的threadLocals,如果该map不为空,以当前ThreadLocal对象做为key取出对应的Entry得到value值。若没有顺利取得value值,则会执行setInitialValue()方法,我们来看看该方法做了什么。

    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

initialValue()方法为value设置了null值,通过当前线程获取threadLocals,若map存在则调用set方法,否则调用createMap方法创建threadLocals

总结和思考

从以上分析可以了解到,Thread对象持有自己的ThreadLocalMap对象,该对象实质为一个Entry数组,每个Entry是一个键值对,key是当前的ThreadLocal对象,并且对该ThreadLocal对象使用的是弱引用。这里存在两个问题:

  1. 为什么采用这种引用结构;
  2. 这里是否存在内存泄漏问题。

对于问题1,由于ThreadLocal的生命周期普遍长于Thread,因此当Thread生命周期结束以后,即使ThreadLocal仍存在,但由于弱引用的关系,ThreadLocalMap就可以被释放了。

低于问题2,当ThreadLocal提前于Thread结束生命周期,比如线程池这种Thread长期不结束的情况,此时ThreadLocal对象仅有来自ThreadLocalMapEntry的弱引用,因此该ThreadLocal对象时可以被回收掉的,那么接下来就会出现对应的Entry中key被置为null的情况,那么这个Entry就再也不可能被调用到,就发生了内存泄漏。为了处理这种情况,在源码的set方法中我们看到了大量的脏Entry清理策略,另外其实在remove方法中也有类似的清理策略,我们也在使用完ThreadLocal后采用手动调用remove方法的方式来避免内存泄漏的情况。

你可能感兴趣的:(ThreadLocal解析)