ThreadLocal原理解析

什么是ThreadLocal

ThreadLocal用于储存专属于某个线程变量的值(线程私有)。同一个ThreadLocal变量,在不同线程下读取到的变量值是不同的,可以做到变量在线程之间的隔离。和传统方式定义的变量不同,传统方式的成员变量是多个线程共享的。

ThreadLocal的使用方法

定义ThreadLocal变量

ThreadLocal最好使用static类型声明。具体原因在后面源代码分析中解释。

Java 8之前使用下面的方式定义ThreadLocal并指定初始值

private static ThreadLocal local = new ThreadLocal() {
    @Override
    protected Integer initialValue() {
        return 0;
    }
};

在java 8 之后推荐使用如下的方式:

private static final ThreadLocal local = ThreadLocal.withInitial(() -> 0);

注意:必须使用以上两种方式之一来指定初始值。举一个反例,比如下面的代码:

static ThreadLocal local = new ThreadLocal<>();
static Object o = new Object();

public static void someMethod() {
    if (null == local.get()) {
        local.set(o);
    }
}

这段逻辑看似为ThreadLocal指定了默认值。但是实际运行时,每个线程持有的ThreadLocal的值都是同一个Object。不同线程之间的变量仍然是共用的,没有线程隔离。一定不要这样使用ThreadLocal

读写ThreadLocal变量

这里比较简单,直接代码说明。

// 写入变量值
local.set(2);

// 读取变量值
Integer i = local.get();

清除ThreadLocal变量值

线程不再使用ThreadLocal变量,需要调用remove方法,否则会发生内存泄漏。

local.remove()

ThreadLocal的原理

我们从get读取变量值这个方法入手分析。

get方法代码如下:

public T get() {
    // 获取当前Thread
    Thread t = Thread.currentThread();
    // 获取ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 如果map存在
        // 获取该ThreadLocal对应的MapEntry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            // 如果该entry存在,返回entry对应的value
            T result = (T)e.value;
            return result;
        }
    }
    // 否则,执行设置初始值的逻辑
    return setInitialValue();
}

由以上代码可知ThreadLocalMap是变量值的载体。我们看一下getMap方法:

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

返回的是Thread类里面threadLocals的值。

这里的关系可能比较复杂。ThreadLocalMap在Thread类中保存。ThreadLocal变量获取值的时候,先获取当前线程的ThreadLocalMap,在从这个ThreadLocalMap中获取key为调用get方法的ThreadLocal变量所对应的value。我们可以得出如下结论:

  • ThreadLocal变量的值分别在各个Thread中保存。
  • 同一个ThreadThreadLocalMap保存了该Thread在多处ThreadLocal变量中的对应的值。

ThreadLocalMapThreadLocal的一个静态内部类。和java.util.Map类似,ThreadLocalMap也拥有Entry。如下所示:

static class Entry extends WeakReference> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal k, Object v) {
        super(k);
        value = v;
    }
}

每一个Entry对象保存了一个ThreadLocalvalue的对应关系。

这里有一个最特殊的地方,Entry继承了WeakReference。说到WeakReference自然需要提到StrongReference。但Java并没有StrongReference这个类。我们平时写的代码,例如:

Object o = new Object();

这里的o就是一个Strong Reference,即强引用。无论什么时候,只要对象存在强引用,GC时候都不会被回收。

但是WeakReference则不同,如果一个对象仅仅被WeakReference引用,那么GC的时候,该对象会被回收。一个例子如下:

public class WeakRefDemo {
    public static void main(String[] args) {
        WeakDemo weakDemo = new WeakDemo();
        System.out.println(weakDemo.strongReference);
        System.out.println(weakDemo.weakReference.get());
// (1)        Object o = weakDemo.weakReference.get();
        System.gc();
        System.out.println(weakDemo.strongReference);
        System.out.println(weakDemo.weakReference.get());
    }
}

class WeakDemo {
    Object strongReference = new Object();
    WeakReference weakReference = new WeakReference<>(new Object());
}

为了对比结果,WeakDemo类中同时定义了强引用和弱引用。
保持(1)处注释不动,运行代码,会得到类似如下输出:

java.lang.Object@1b6d3586
java.lang.Object@4554617c
java.lang.Object@1b6d3586
null

我们发现GC过后,strongReference依然可访问,然而weakReference已经被回收,值变成了null。

如果取消(1)这一行的注释,再次执行代码,会得到类似如下的输出:

java.lang.Object@1b6d3586
java.lang.Object@4554617c
java.lang.Object@1b6d3586
java.lang.Object@4554617c

和上一次不同,这次在GC之前weakReference引用的对象在别处存在强引用,因此它不再被GC回收。

我们回到ThreadLocalMapEntry类。Entry是一个指向ThreadLocalWeakReference。而定义ThreadLocal的对象会持有对ThreadLocal的强引用。如果Entry指向ThreadLocal不使用WeakReference,即便是定义了ThreadLocal的对象不再使用,只要线程不销毁,还是能够通过Thread -> ThreadLocalMap -> Entry -> ThreadLocal -> 定义ThreadLocal的对象这条引用链追溯到,因此会有严重的内存泄漏问题。

一开始提到ThreadLocal最好使用static变量类型。因为static修饰符避免了不同的实例创建出不同的ThreadLocal变量。虽然不添加static修饰也不影响使用,但是会造成变量浪费。ThreadLocal变量真正的内容不是在ThreadLocal中存储,而是在各个线程自己的ThreadLocalMap中。所以说建议使用static修饰ThreadLocal变量。

我们继续分析ThreadLocalget方法。如果ThreadThreadLocalMap为null,或者是线程的ThreadLocalMap中不存在key为这个ThreadLocal变量的entry,会执行设置初始值的操作。方法代码如下所示:

private T setInitialValue() {
    // 这里调用的是使用方法里介绍的,设置ThreadLocal初始值的方法
    // 设定初始值需要继承ThreadLocal类,并覆盖这个方法
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        // 如果map存在,在map中设置ThreadLocal和value的映射关系
        // 当然这个方法还隐藏了其他逻辑,后面分析
        map.set(this, value);
    else
        // 如果map为null,为线程创建一个ThreadLocalMap
        // 并创建一个entry,保存当前ThreadLocal和value的对应关系
        createMap(t, value);
    return value;
}

我们看一下ThreadLocalMapset方法(ThreadLocalset方法也间接调用了该方法):

private void set(ThreadLocal key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    // 获取ThreadLocalMap的容量,初始值为16
    int len = tab.length;
    
    // 每一个ThreadLocal变量都有一个独一无二的hashCode
    // 该值与map容量减一按位与之后,得到的值换算为10进制,作为该ThreadLocal变量值在map中对应entry的下标存储
    // 无论map怎么扩容,内部table的length总是2的n次方数,减去1之后可以获取到一个每一位全是1的二进制数
    // ThreadLocal和这个数按位与之后可以在table中分布的更为平均,尽量避免hash碰撞
    int i = key.threadLocalHashCode & (len-1);

    // nextIndex获取的是下一个index(++index),如果++index越界,返回0,数组从头开始
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal k = e.get();

        // 如果k是当前的ThreadLocal变量,说明找到了当前ThreadLocal对应的entry,更新它的value并返回
        if (k == key) {
            e.value = value;
            return;
        }
        // 如果k为null,说明之前这个entry对应的ThreadLocal变量已经被回收
        // key已经被回收的entry在源代码中称为stale entry
        // 这个stale entry会被替换掉
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
        // 如果这两个if都没有执行,说明可能存在hash碰撞,即其他ThreadLocal的hashcode运算之后的下标和当前ThreadLocal运算的结果一致
        // 并且其他ThreadLocal的变量已经在map中储存
        // 这时候尝试继续寻找key对应的entry
        // 也可能是key对应的entry没有创建
        // 这种情况会一直到for循环执行完毕,在下面步骤创建出新的entry
    }

    // 如果下标对应的entry不存在,创建一个新的
    tab[i] = new Entry(key, value);
    // map的大小加一
    int sz = ++size;
    // 查找当前index之后,以2为底sz的对数个entry,如果有stale entry,清除他们,具体稍后分析
    // 这里之所以没有扫描所有的stale entry,是为了平衡清除stale entry操作和时间的消耗
    // 如果没有发现stale entry,判断sz是否大于阈值
    // 阈值为map容量的三分之二
    // 如果超过了阈值,清除所有的stale entry,并且再次判断是否需要扩容
    // rehash流程稍后分析
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

接下来分析下replaceStaleEntry的代码。该方法参数中存放stale entry的index称为staleSlot

private void replaceStaleEntry(ThreadLocal key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    int slotToExpunge = staleSlot;
    // 从staleSlot位置向前查找其他stale slot,直到发现前面entry的slot不存在为止
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    // 从stale位置的下一个index开始循环
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        // 如果找到了一个entry,key是当前的ThreadLocal
        if (k == key) {
            // 替换value为新的值
            e.value = value;

            // 和stale entry交换位置
            // 因为当前ThreadLocal对应的entry的位置(ThreadLocal的hashCode决定,之前已分析)本来就应该是stale entry的位置
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            // 如果slotToExpunge值没有改变,那么就从下标i开始清理stale entry
            // 因为i处的entry已经被交换为stale entry
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // expungeStaleEntry清理stale slot到下一个null slot之前所有的stale entry(左闭右开区间)
            // 返回值是下一个null slot(下标)
            // 再扫描下一个null slot(下标)往后以2为底len的对数个slot内所有的stale entry并清除
            // cleanSomeSlots还包含其他行为,后续分析
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        // 如果在前面的for循环(倒着查找stale entry)没有发现其他的stale entry
        // 并且当前index正好是stale slot
        // 设置slotToExpunge为当前index
        // 因为该if进入之后slotToExpunge会被修改,之后不会再次进入
        // 所以说这里设置slotToExpunge为staleSlot到下一个null slot之间(开区间)的第一个stale slot
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    // 如果没有发现key为当前ThreadLocal的entry,创建一个新的entry
    tab[staleSlot].value = null
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    // 如果有其他的stale entry,运行清理方法
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

经过分析可知replaceStaleEntry从staleSlot位置向前查找stale entry,直到遇到null slot为止。比如:

stale | null | stale | entry | stale | stale | entry
  A      B       C       D       E       F       G

map中存储有上面所示的多个entry。执行replaceStaleEntry传入的是F位置。那么根据以上逻辑,slotToExpunge最终会指向C。
接下来replaceStaleEntry会从C位置开始清除stale slot。

我们分析下清理stale entry的逻辑。位于expungeStaleEntry方法。代码如下所示:

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    // 设置stale entry的value为null,释放掉value的引用
    tab[staleSlot].value = null;
    // 释放掉stale entry的引用
    tab[staleSlot] = null;
    // 减小size
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    // 从stale slot下一个index开始循环,直到entry为null为止
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal k = e.get();
        if (k == null) {
            // 如果发现key为null,说明又发现了一个stale entry,执行数据擦除
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 根据key对应的hash,计算它的slot位置(下标)
            int h = k.threadLocalHashCode & (len - 1);
            // 如果key不在它本应该属于的下标位置
            if (h != i) {
                // 清除下标i对应的entry
                // 重新计算安排entry的位置
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                // 从h位置向后查找,直到发现空位置
                // 把现在现在遍历到的这个entry放置在这个空位置
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    // 返回方法参数中staleSlot后面第一个空slot的下标
    return i;
}

由以上分析可知expungeStaleEntry不仅清理了staleSlot下标对应的entry,还顺便清理了staleSlot到下一个null slot之间的所有stale slot。除此之外还重新分配上述区间内实际存储下标和从hashCode计算出的下标不一致的entry的位置。

接下来需要分析cleanSomeSlots方法。该方法从参数i这个下标开始(不包括i),向后查找以2为底n的对数个slot,如果中间发现有stale slot,调用expungeStaleEntry方法清除。同时重置n为map的容量(即需要再扫描至少log2(容量)-1个entry)。

正如英文注释所说,该方法之所以没有一开始就遍历整个map去清除stale entry,是因为需要从性能方面考虑作出权衡。

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    // 获取map的容量
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 如果e为stale entry
            // 重置n为map容量
            n = len;
            // 设置removed标记为true
            removed = true;
            // 清除下标i位置对应的entry
            i = expungeStaleEntry(i);
        }
        // n每次循环都除以2
    } while ( (n >>>= 1) != 0);
    // 如果有entry被擦除,返回true
    return removed;
}

还剩下一个rehash方法。rehash方法负责清理map中所有的stale entry。如果清理过后map的已用空间还是过大(超过阈值的四分之三),会进行扩容操作。

rehash方法代码如下:

private void rehash() {
    // 清理所有的stale slot
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    // 如果已用空间仍然大于等于阈值的四分之三,执行扩容操作
    if (size >= threshold - threshold / 4)
        resize();
}

expungeStaleEntries代码如下所示:

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

该方法遍历所有的entry,清理stale entry。

最后是负责扩容,重新计算slot位置的resize方法代码:

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    // 扩容为原来的2倍
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    // 遍历老的map entry
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal k = e.get();
            if (k == null) {
                // 如果是stale entry,清理掉他的value
                e.value = null; // Help the GC
            } else {
                // 根据hashCode和新的length计算出entry所属的slot(下标)
                int h = k.threadLocalHashCode & (newLen - 1);
                // 如果计算出来的slot被占用(发生hash碰撞),逐个向后找到一个空闲的slot
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                // 放置entry到空闲slot
                newTab[h] = e;
                // 计数加一
                count++;
            }
        }
    }

    // 重新设置新的阈值
    setThreshold(newLen);
    // 设置新的size和table
    size = count;
    table = newTab;
}

Java8设置ThreadLocal初始值的逻辑

在Java8之后不推荐使用继承ThreadLocal重写initialValue方法的方式来指定初始值。

Java8建议使用withInitial静态方法,提供一个Supplier方法(无参数有返回值)作为默认值生成器。

withInitial方法代码如下所示:

public static  ThreadLocal withInitial(Supplier supplier) {
    return new SuppliedThreadLocal<>(supplier);
}

方法返回了一个SuppliedThreadLocal类型。我们查看下它的代码:

static final class SuppliedThreadLocal extends ThreadLocal {

    private final Supplier supplier;

    SuppliedThreadLocal(Supplier supplier) {
        this.supplier = Objects.requireNonNull(supplier);
    }

    @Override
    protected T initialValue() {
        return supplier.get();
    }
}

该类继承了ThreadLocal。重写的initialValue方法调用supplier的get方法并返回。实际上和JDK8之前的使用方式没有区别,只不过Java帮我们做了一层封装,可以用更为优雅的方式指定初始值。

ThreadLocal内存泄漏问题

文章开始的时候介绍entry继承了WeakReference。这样使用的目的是为了帮助GC回收ThreadLocal变量所在对象。这是因为Thread -> ThreadLocalMap -> Entry -> ThreadLocal -> ThreadLocal变量所在对象这一条引用链中Entry -> ThreadLocal这一环是弱引用。尽管如此,如果不恰当使用ThreadLocal,内存泄漏问题依然会存在,因为entry对象本身并不会因为弱引用的缘故自动回收。

按照JDK文档,线程不再使用ThreadLocal变量的时候,需要调用remove方法,清除对应的entry释放内存。避免形成内存泄漏。

remove方法。该方法间接调用了ThreadLocalMapremove方法。如下所示:

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

ThreadLocalMapremove方法和解释如下所示:

private void remove(ThreadLocal key) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算key对应的下标
    int i = key.threadLocalHashCode & (len-1);
    // 从下标i处向后逐个遍历tab中的entry,直到遇到null entry
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            // 如果找到需要清理的entry,清理它的引用
            e.clear();
            // 清理这个stale entry
            expungeStaleEntry(i);
            // 最后返回
            return;
        }
    }
}

本文为原创内容,欢迎大家讨论、批评指正与转载。转载时请注明出处。

你可能感兴趣的:(ThreadLocal原理解析)