ThreadLocal 工作原理

一. 介绍

ThreadLocal 提供了线程局部(thread-local)变量。这些变量不同于普通变量,因为访问某个变量(通过其get或set方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal实例通常是类中的private static字段,它们希望将状态与某一个线程(例如,用户ID或事务ID)相关联。

ThreadLocal 适用于每个线程需要自己独立的实例且该实例需要在多个方法中被使用,也即变量在线程间隔离而在方法或类间共享的场景。

二. 使用示例

public class ThreadLocalDemo {

    private static ThreadLocal local = new ThreadLocal<>();

    public static void set() {
        local.set(Thread.currentThread().getName());
    }
    public static String getString() {
        return local.get();
    }

    public static void main(String[] args) throws Exception{
        set();
        System.out.println(getString());

        Thread thread0 = new Thread(() -> {
            set();

            System.out.println(getString());
        });

        Thread thread1 = new Thread(() -> {
           System.out.println(getString());
        });

        thread0.start();
        thread1.start();

        thread0.join();
        thread1.join();

        System.out.println(getString());
    }

}

输出:


image.png

从输出结果可以看出,虽然定义了一个全局静态变量的ThreadLocal 实例,但是每个线程里调用get 方法获取到的值是不一样的,这就可以看出ThreadLocal 提供线程级别局部变量的能力。

三. 工作原理

那么是怎么实现这种线程级别的局部变量的呢?实际上是每个线程内部维护了一个Map,这个Map 其实是ThreadLocal 的一个内部类ThreadLocalMap,他们之间的关系可以用下图表示,


image.png

也就是说这个map 的key 就是ThreadLocal 对象,value 就是通过ThreadLocal 的set 方法设置的值,因为这个map 是由线程负责维护的,而ThreadLocal 对象仅仅是作为key 来存在,所以也就实现了同一个ThreadLocal对象 在不同线程调用 get 方法获取到的值是不同的这样的线程级变量。

ThreadLocal:

刚才也提到,我们是通过调用ThreadLocal 对象的set 和get 方法来实现设置和获取变量值的,这个两个方法的实现代码如下,

public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);             // 首先获得当前线程的ThreadLocalMap 属性
        if (map != null)
            map.set(this, value);                           // map 不为空就给它set 值,key 就是当前ThreadLocal 对象,value 就是要设置的值
        else
            createMap(t, value);                            // 如果map 是空,就创建一个map,creatMap 方法的实现就是new 一个ThreadLocalMap 对象并赋值给线程的map属性
}

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);                                     // 首先获取当前线程的map 属性
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);    // 如果map 不为空就根据当前ThreadLocal 对象取出 Entry
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;                                      // 如果Entry 不为空就获得它的value 并进行类型转换
                return result;
            }
        }
        return setInitialValue();                                                   // 如果map 是空会返回设置初始值方法
}

private T setInitialValue() {                                                           // 设置初始值的方法主要是调用了 initialValue 方法,然后为map 赋值
        T value = initialValue();                                                   // initialValue 的默认实现是返回null
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
}

上面提到了initialValue 方法,在使用ThreadLocal 时,当你在一个线程里没有调用set 方法而是直接get 就会返回null(如二. 中示例的第三行输出),而如果在初始化ThreadLocal 时重写了该方法,那么如果没调用set 方法,就会获得initialValue 方法的返回值。如下

private static ThreadLocal local = new ThreadLocal(){           // 把上例中的初始化改为重写了initialValue 的形式
        @Override
        protected String initialValue() {
            return "default string";
        }
};

看一下输出:


image.png

可以看第三行,未调用set 方法的线程,调用get 获得的是initialValue 方法生成的值。

ThreadLocalMap:

ThreadLocalMap 并没有实现Map 接口,它内部定义了一个继承弱引用WeakReference 的Entry 类

static class Entry extends WeakReference> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
}

从上面代码可以看出,在Entry 这个键值对里,它key 其实是ThreadLocal 对象的弱引用,之所以这么设计,是为了避免内存泄漏。因为如果使用强引用,那么当ThreadLocal 对象不再使用时,因为Map 持有对它的强引用而导致ThreadLocal 对象无法释放,产生内存泄漏。

在ThreadLocalMap 中使用一个Entry 数组来保存数据。ThreadLocalMap 提供了两个重载的构造方法,其中ThreadLocal 的createMap 方法调用的是这一个:

ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];                                                        // table 就是保存数据的Entry 数组
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);  // 根据key 计算一个索引
            table[i] = new Entry(firstKey, firstValue);                                         // 创建一个Entry 并保存在数组中指定索引处
            size = 1;                                                                                                               // 设置size 属性
            setThreshold(INITIAL_CAPACITY);                                                                 // 设置扩容因子,该方法的实现就是将扩容因子设置为数组长度的2/3
}

从上面ThreadLocal 的代码中可以看出,它的set、get 方法实际上调用了ThreadLocalMap 的set 和getEntry 方法

private void set(ThreadLocal key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);                          // 先计算出该key 应该在的索引位置

            for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {     // nextIndex 计算下一个索引,就是在i 的基础上加1,如果超过len 则置为0
                ThreadLocal k = e.get();                                                 // 拿到当前位置的ThreadLocal 对象

                if (k == key) {                                                     // 如果和key 相等则直接替换value
                    e.value = value;
                    return;
                }

                if (k == null) {                                                    // 如果当前位置不存在一个key
                    replaceStaleEntry(key, value, i);           // 
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
}

replaceStaleEntry 方法是将set 操作期间遇到的旧Entry 替换为指定key 的Entry。无论指定key 的Entry 是否已存在,value参数中传递的值都存储在Entry 中。作为副作用,此方法将清除包含旧Entry 的“run”中的所有旧Entry。(run是两个空槽之间的一系列Entry)

private Entry getEntry(ThreadLocal key) {
            int i = key.threadLocalHashCode & (table.length - 1);           // 计算key 对应的索引
            Entry e = table[i];                                                     // 数组中取出Entry
            if (e != null && e.get() == key)                            // Entry 不为空就返回
                return e;
            else
                // Entry 为空就执行这个方法,该方法的实现就是从当前索引位置向后查找Entry,因为在设置的时候,如果当前索引位置有值
                // 会向后面的索引设置值,同时会清空key 为null 的值
                return getEntryAfterMiss(key, i, e);            
}

四. 注意事项

1. 内存泄漏

上面提到了ThreadLocalMap 使用了ThreadLocal 的弱引用作为key,来解决ThreadLocal 不再使用时而无法回收可能导致的泄漏。

但是这时还可能有另外一种泄漏情况,那就是因为ThreadLocal 对象可能被回收, 这样 ThreadLocalMap 中就会出现key 为null 的Entry,就没有办法访问这些key 为null 的Entry 的value,如果当前线程再迟迟不结束的话,这些key 为null 的Entry 的value 就会一直存在一条强引用链:Thread Ref -> Thread -> ThreadLocalMap -> Entry -> value 永远无法回收,造成内存泄漏。

ThreadLocal 解决第二种内存泄漏的方式,就是在set 和 get 时,清空key 为null 的value。

但是有的时候,ThreadLocal 的避免方法可能会无效,比如:使用static的ThreadLocal,延长了ThreadLocal的生命周期,可能导致的内存泄漏(案例http://blog.xiaohansong.com/2016/08/09/ThreadLocal-leak-analyze/)。

2. 在线程池中的使用

通过上面的原理分析,可以知道,ThreadLocal 是线程级变量,然而线程池中的线程是复用的,这就可能导致脏数据的问题,比如线程池中的线程A,为ThreadLocal 设置了值 a,然后在之后执行任务的代码用到了ThreadLocal,这时当A 执行完第一个任务,执行第二个任务时,还是要获取ThreadLocal 的数据,这时获取的还是a,但实际第二次任务需要的并不是a,这样就可能会导致严重的业务bug。示例:

public class JavaTest {

    private static ThreadLocal threadLocal = new ThreadLocal<>();

    public static void main(String[] args) throws Exception{
        ExecutorService service = Executors.newFixedThreadPool(3);      // 创建一个3 个线程的线程池

        for (int i = 0; i < 5; i++) {                                                                   // 执行5 个任务
            service.execute(new TestRunnable(i));
        }
    }

    static class TestRunnable implements Runnable {

        private int num;

        TestRunnable(int num) {
            this.num = num;
        }

        @Override
        public void run() {
            String taskValue = "task#" + num + "'s value";
            if (threadLocal.get() == null) {                                        // 模拟为ThreadLocal 赋值
                threadLocal.set(taskValue);
            }

            doSomething();                                                                          // 模拟业务代码
        }

        private void doSomething() {
            String taskName = "task#" + num;
            System.out.println("I am " + taskName + ", and I need " + threadLocal.get());           // 模拟使用ThreadLocal 中的值
        }
    }

}

看下输出:


image.png

可以看到数据出现了明显的错乱,task#3 需要的value 却是task#0 的value,这种错误是不能允许的,那么怎么避免呢,我们在doSomething 方法中加这么一句:

private void doSomething() {
            String taskName = "task#" + num;
            System.out.println("I am " + taskName + ", and I need " + threadLocal.get());
            threadLocal.remove();                           // 就是这句,清除ThreadLocal 保存的数据
}

下面再来看下输出:


image.png

嗯,task 的name 和需要的value 对应的整整齐齐,没毛病了。你可能觉得是我在给ThreadLocal 赋值时写的有问题,但实际在开发中,确实可能出现不能及时更新ThreadLocal 值的情况。

所以,使用ThreadLocal 的一个良好习惯就是在调用完 get 方法之后,在合适的时机调用 remove 方法,清空ThreadLocal。

在涉及到线程池时,要尤其注意以上点,否则可能出case。

你可能感兴趣的:(ThreadLocal 工作原理)