Java多线程(11)——ThreadLocal源码剖析

目录

1.概述

2.图解+源码分析ThreadLocal原理

2.1 Thread类的两个ThreadLocalMap类型的参数

2.2 ThreadLocalMap详解

(1)成员变量与内部类

(2)构造方法

(3)获取前一个/后一个索引的方法和设置扩容阈值的方法

(4)getEntry方法

(5)ThreadLocal的内存泄露

(6)set

(7)remove

2.3 ThreadLocal详解

(1)get()和set()

(2)setInitialValue()

(3)remove()

(4)threadLocalHashCode

3.ThreadLocal不支持继承性

3.1 ThreadLocal不支持继承性的演示

3.2 支持继承的InheritableThreadLocal


1.概述

  • ThreadLocal是一个本地线程副本变量工具类

前面的博文中有讲到过,引发线程安全的原因主要是

  • 1.存在共享资源
  • 2.存在多个线程去操作同一个共享资源

示例如下:

package ThreadLocal;

class ThreadLocalDemo{
    //共享资源
    private int count = 20;

    //写操作
    public int decrement(){
        try {
            Thread.sleep(100);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        count--;
        return count;
    }

    public static void main(String[] args) {
        ThreadLocalDemo demo = new ThreadLocalDemo();
        //线程A共享资源
        new Thread(()->{
            for (int i = 0; i < 10; i++) {
                System.out.println(Thread.currentThread().getName() + "\t" + demo.decrement());
            }
        },"A").start();
        //线程B共享资源
        new Thread(()->{
            for (int i = 0; i < 10; i++) {
                System.out.println(Thread.currentThread().getName() + "\t" + demo.decrement());
            }
        },"B").start();
    }

}

Java多线程(11)——ThreadLocal源码剖析_第1张图片

而我们前面的思路是通过sychronized锁或者CAS无锁策略来控制多个线程的执行顺序来保证数据的一致性,即我们前面的思路都是从引发线程的第二个原因入手的。

那么我们能否从第一个原因入手呢?既然是共享资源引发的问题,我们能不能让它不是共享资源呢?ThreadLocal提供了线程安全的另一种思路,即:

  • 通常情况下,我们创建的变量是可以被任何一个线程访问并修改的。而ThreadLocal让各个线程都拥有一份线程私有的数据,让每个线程绑定自己的值,线程在操作数据的时候,仅仅是操作自己线程内部的变量,这样线程之间的变量互不干扰,在高并发场景下,可以实现无状态的调用。

如果你创建了一个ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的本地副本,这也是ThreadLocal变量名的由来。他们可以使用 get() 和 set() 方法来获取默认值或将其值更改为当前线程所存的副本的值,从而避免了线程安全问题。

案例:

package ThreadLocal;

class ThreadLocalDemo{
    private ThreadLocal count = new ThreadLocal(){

        //重写ThreadLocal的初始化方法,用于初始化ThreadLocal变量
        @Override
        protected Integer initialValue() {
            return 0;
        }
    };


    public int getNext(){
        Integer value = count.get();
        value++;
        count.set(value);
        return value;
    }

    public static void main(String[] args) {
        ThreadLocalDemo demo = new ThreadLocalDemo();

        new Thread(()->{
            for (int i = 0; i < 5; i++) {
                System.out.println(Thread.currentThread().getName() + "\t" + demo.getNext());
            }
        },"A").start();

        new Thread(()->{
            for (int i = 0; i < 5; i++) {
                System.out.println(Thread.currentThread().getName() + "\t" + demo.getNext());
            }
        },"B").start();

        new Thread(()->{
            for (int i = 0; i < 5; i++) {
                System.out.println(Thread.currentThread().getName() + "\t" + demo.getNext());
            }
        },"C").start();
    }

}

Java多线程(11)——ThreadLocal源码剖析_第2张图片

2.图解+源码分析ThreadLocal原理

Java多线程(11)——ThreadLocal源码剖析_第3张图片

以下分析结合上图去理解:

我们先从上往下去分析:

2.1 Thread类的两个ThreadLocalMap类型的参数

  • 1.每个线程Thread都有两个自己的属性——threadlocalsinheritableThreadLocals,它是ThreadLocalMap类型的,Thread类中源码如下:

Java多线程(11)——ThreadLocal源码剖析_第4张图片

  • 而从源码也可以看出ThreadLocalMap是ThreadLocal的内部类,我们可以把 ThreadLocalMap 理解为ThreadLocal 类实现的定制化的 HashMap
  • 默认情况下这两个变量都是null,只有当前线程调用 ThreadLocal 类的 setget方法时才创建它们,实际上调用这两个方法的时候,我们调用的是ThreadLocalMap类对应的 get()set()方法。

2.2 ThreadLocalMap详解

  • 2.我们先来看一看ThreadLocalMap这个内部类

(1)成员变量与内部类

static class ThreadLocalMap {

    /**
     * 定义了数组中存储的对象——键值对   键:ThreadLocal,值:value
     *
     */
    static class Entry extends WeakReference> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
     * 初始化数组的容量
     */
    private static final int INITIAL_CAPACITY = 16;

    /**
     * 存放Entry的数组
     */
    private Entry[] table;

    /**
     * 数组中存放的Entry的个数
     */
    private int size = 0;

    /**
     * 扩容阈值
     *
     * 默认为0
     */
    private int threshold; // Default to 0
}
  • 可以发现ThreadLocalMap的底层存放的是数组,而该数组中存放的元素是键值对,即Entry对象
  • 而Entry是ThreadLocalMap中定义的静态内部类,它继承自WeakReference,即弱引用
    • 这时,会奇怪在Entry中没有看到有定义key字段呢?
    • Java多线程(11)——ThreadLocal源码剖析_第5张图片
    • 其实可以看到在Entry的构造方法中,调用了super(k),即调用了WeakReference的构造方法
    • 同理,WeakReference又调用了Reference的构造方法
    • Java多线程(11)——ThreadLocal源码剖析_第6张图片
    • 可以发现在Reference的构造方法中,最终将k赋值给了Reference中定义的字段referent
    • 所以最终ThreadLocalMap中的key是referent字段,由于它是从WeakReference继承下来的,所以,key是一个弱引用
  • 以上几个其他参数,与HashMap中的类似,size为元素个数,thresold为扩容阈值 

(2)构造方法

    /**
     * 用一组键值对初始化ThreadLocalMap
     */
    ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
        //创建默认容量16的Entry数组
        table = new Entry[INITIAL_CAPACITY];
        /**
         * 通过传入的ThreadLocal的threadLocalHashCode值计算它存放在数组中的索引
         *
         */

        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);

        //将该Entry存放进该索引位置
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        //设置初始的扩容阈值,此处为16*2/3=10
        setThreshold(INITIAL_CAPACITY);
    }

    /**
     * 通过父线程的ThreadLocalMap创建的ThreadLocalMap
     *
     * 此方法只在ThreadLocal的createInheritedMap中被调用
     */
    private ThreadLocalMap(ThreadLocalMap parentMap) {
        //根据父线程的ThreadLocalMap的参数创建新的大小的数组,并设置扩容阈值为与父线程扩容阈值相同
        Entry[] parentTable = parentMap.table;
        int len = parentTable.length;
        setThreshold(len);
        table = new Entry[len];

        //遍历父线程的ThreadLocal底层的Entry数组
        for (int j = 0; j < len; j++) {
            Entry e = parentTable[j];
            if (e != null) {
                @SuppressWarnings("unchecked")
                ThreadLocal key = (ThreadLocal) e.get();
                if (key != null) {
                    /**
                     * childValue是留给子类去重写的一个方法,此处用了模板方法模式
                     *
                     * 在ThreadLocal的其中一个子类InheritableThreadLoca中,是直接返回了参数中传入的父类的值
                     */
                    Object value = key.childValue(e.value);
                    Entry c = new Entry(key, value);
                    int h = key.threadLocalHashCode & (len - 1);
                    //从计算到的索引开始,不断往后计算新的索引,找到第一个没有Entry即指向null的位置,将该键值对插入
                    while (table[h] != null)
                        h = nextIndex(h, len);
                    table[h] = c;
                    size++;
                }
            }
        }
    }
  • 这里重点说一下第二个构造方法,它是通过传入父类线程的ThreadMap来去创建一个ThreadMap,这个方法只在ThreadLocal的createInheritedMap方法中被调用了
  • Java多线程(11)——ThreadLocal源码剖析_第7张图片
  • 而createInheritedMap方法在哪里被使用了呢?
  • 关于这里调用这个方法的作用见下文关于ThreadLocal不支持继承性的讲解

(3)获取前一个/后一个索引的方法和设置扩容阈值的方法

    /**
     * 将扩容阈值设置为传入长度len的2/3
     */
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

    /**
     * 增加i,但是最大增加到len-1,当i=len的时候,i又回到0
     */
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

    /**
     * 减少i,但是最小减少到0,当i=-1时,i回到len-1
     */
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
        Thread
    }

以下是三个核心方法getEntry,set,remove

(4)getEntry方法

    /**
     * 获取指定ThreadLocal所对应的Entry  
     */
    private ThreadLocal.ThreadLocalMap.Entry getEntry(ThreadLocal key) {
        //计算传入的key在数组中的索引
        int i = key.threadLocalHashCode & (table.length - 1);
        //获取到该索引处的Entry
        ThreadLocal.ThreadLocalMap.Entry e = table[i];
        //如果该Entry不为空,并且此Entry的key与我们传入的Entry相等的话,此Entry就是我们要找的
        if (e != null && e.get() == key)
            return e;
        //没有在通过key计算的索引位置找到Entry的话,调用getEntryAfterMiss
        else
            return getEntryAfterMiss(key, i, e);
    }
  • 上述代码中e.get()实际上调用的Reference中的get方法,所以e.get()返回的是ThreadLocalMap的key
  • 上述代码中为什么会有在原来的key中获取不到Entry的情况呢?为什么有两个判断条件呢?,只要e存在不就是它要取的Entry?
  • 答案是否定的,是在判断我们传入的对应的key本身就不存在的情况,
  • 而关于的解释就要引出另一个问题了——内存泄漏

(5)ThreadLocal的内存泄露

  • ThreadLocalMap 中使用的 key 为 ThreadLocal 的弱引用,value 是强引用。
  • 所以,如果 ThreadLocal 没有被外部强引用的情况下,在垃圾回收的时候会 key 会被清理掉,而 value 不会被清理掉。这样一来,ThreadLocalMap 中就会出现key为null的Entry。
  • 假如我们不做任何措施的话,value 永远无法被GC 回收,这个时候就可能会产生内存泄露。ThreadLocalMap实现中已经考虑了这种情况,在调用 set()get()remove() 方法的时候,会清理掉 key 为 null 的记录。使用完 ThreadLocal方法后最好手动调用remove()方法

上述就是判断key是否已经被经过了处理,rehash或者已经被清除,然后是的话,就调用getEntryAfterMiss方法

    /**
     * key没有在通过它的threadLocalHashCode计算出的直接索引地方找到,调用此方法
     */
    private ThreadLocal.ThreadLocalMap.Entry getEntryAfterMiss(ThreadLocal key, int i, ThreadLocal.ThreadLocalMap.Entry e) {
        
        ThreadLocal.ThreadLocalMap.Entry[] tab = table;
        int len = tab.length;
        //从e开始遍历Entry数组,直到遇到第一个空的Entry为止
        while (e != null) {
            //拿到此位置Entry的key
            ThreadLocal k = e.get();
            //当前位置的key是我们要找的key,那么当前就返回当前的Entry,它就是我们要找的
            if (k == key)
                return e;
            //此位置的key为null,代表此位置的Entry已经过期
            if (k == null)
                //调用一次从该位置删除过期Entry的方法
                expungeStaleEntry(i);
            //此位置的key不为null,也不是我们要找的,那就继续迭代
            else
                i = nextIndex(i, len);
            e = tab[i];
        }
        //最终都没有找到的话,返回null
        return null;
    }

(6)set

    /**
     * 将键值对插入数组中
     *
     * @param key the thread local object
     * @param value the value to be set
     */
    private void set(ThreadLocal key, Object value) {

        // We don't use a fast path as with get() because it is at
        // least as common to use set() to create new entries as
        // it is to replace existing ones, in which case, a fast
        // path would fail more often than not.

        ThreadLocal.ThreadLocalMap.Entry[] tab = table;
        int len = tab.length;
        //计算key对应的数组索引位置
        int i = key.threadLocalHashCode & (len-1);

        //从该索引位置开始遍历,直到遇到null为止
        for (ThreadLocal.ThreadLocalMap.Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal k = e.get();
            //当前Entry的key是我们要找的key的话,直接修改的替代它的值即可
            if (k == key) {
                e.value = value;
                return;
            }
            //key已经过期,它使用replaceStaleEntry进行处理
            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }
        //始终没有找到key并且没有遇到key为null的话,创建新的Entry,放到为null的该为止
        tab[i] = new ThreadLocal.ThreadLocalMap.Entry(key, value);
        int sz = ++size;
        //当数组中不存在过期了的key的话,并且新的元素个数大于扩容阈值,就进行rehash
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }

而rehash做了两件事:

  • 删除所有过期了的key对应的Entry,并进行rehash(实际是在expungeStaleEntry中进行的)
  • 扩容
    /**
     * 重新放置数组中元素的位置或者进行扩容
     */
    private void rehash() {
        //删除所有的过期Entry
        expungeStaleEntries();

        // Use lower threshold for doubling to avoid hysteresis
        if (size >= threshold - threshold / 4)
            //扩容
            resize();
    }

    /**
     * 删除所有的过期Entry
     */
    private void expungeStaleEntries() {
        ThreadLocal.ThreadLocalMap.Entry[] tab = table;
        int len = tab.length;
        for (int j = 0; j < len; j++) {
            ThreadLocal.ThreadLocalMap.Entry e = tab[j];
            if (e != null && e.get() == null)
                expungeStaleEntry(j);
        }
    }


        /**
     * 先删除我们传进来的位置的过期Entry
     * 然后真正Rehash是在此方法中进行的
     */
    private int expungeStaleEntry(int staleSlot) {
        ThreadLocal.ThreadLocalMap.Entry[] tab = table;
        int len = tab.length;

        //删除掉staleSlot位置的Entry
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;

        // Rehash until we encounter null
        ThreadLocal.ThreadLocalMap.Entry e;
        int i;
        //从staleSlot的下一个位置开始遍历Entry,直到某个位置遇到Entry等于null
        for (i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) {
            //获取当前Entry的ThreadLocal
            ThreadLocal k = e.get();
            //如果该key是空的,将它的值也置空,将整个数组该位置也置空
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            //如果该key不是空的
            } else {
                //取得它通过HashCode计算的索引值
                int h = k.threadLocalHashCode & (len - 1);
                //如果通过HashCode计算的值不等于当前的索引
                if (h != i) {
                    //将当前索引位置的值置空
                    tab[i] = null;

                    //从h开始遍历寻找下一个为null的位置
                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    //将该原来在位置i处的Entry存放到新的null位置
                    tab[h] = e;
                }
            }
        }
        return i;
    }

扩容方法

    /**
     * 扩容为原来数组的两倍
     */
    private void resize() {
        ThreadLocal.ThreadLocalMap.Entry[] oldTab = table;
        int oldLen = oldTab.length;
        //扩容为原来两倍
        int newLen = oldLen * 2;
        ThreadLocal.ThreadLocalMap.Entry[] newTab = new ThreadLocal.ThreadLocalMap.Entry[newLen];
        int count = 0;
        //遍历将原来数组中key不为null的Entry复制到新数组当中,
        for (int j = 0; j < oldLen; ++j) {
            ThreadLocal.ThreadLocalMap.Entry e = oldTab[j];
            if (e != null) {
                ThreadLocal k = e.get();
                if (k == null) {
                    e.value = null; // Help the GC
                } else {
                    int h = k.threadLocalHashCode & (newLen - 1);
                    while (newTab[h] != null)
                        h = nextIndex(h, newLen);
                    newTab[h] = e;
                    count++;
                }
            }
        }
        //将各个参数更新
        setThreshold(newLen);
        size = count;
        table = newTab;
    }

(7)remove

    /**
     * 删除key对应的Entry
     */
    private void remove(ThreadLocal key) {
        ThreadLocal.ThreadLocalMap.Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        //遍历查找到该key的Entry,然后删除该Entry
        for (ThreadLocal.ThreadLocalMap.Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            if (e.get() == key) {
                e.clear();
                expungeStaleEntry(i);
                return;
            }
        }
    }

2.3 ThreadLocal详解

当看了ThreadMap的底层源码,我们再来看ThreadLocal就简单多了,最终的变量是放在了当前线程的 ThreadLocalMap 中,并不是存在 ThreadLocal 上,ThreadLocal 可以理解为只是ThreadLocalMap的封装,传递了变量值。

ThrealLocal 类中可以通过Thread.currentThread()获取到当前线程对象后,直接通过getMap(Thread t)可以访问到该线程的ThreadLocalMap对象。

下面我们再来看看ThreadLocal类提供如下几个核心方法:

  • get()方法用于获取当前线程的副本变量值。
  • set()方法用于保存当前线程的副本变量值。
  • initialValue()为当前线程初始副本变量值。
  • remove()方法移除当前前程的副本变量值。

通过下面也可以发现它们其实都是在调用ThreadLocalMap中对应get,set,remove方法

(1)get()和set()

    /**
     * 返回ThreadLocal变量在当前线程对应的threadLocals中存储对应的的value
     */
    public T get() {

        Thread t = Thread.currentThread();
        //1.获取当前线程的ThreadLocalMap对象threadLocals
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //2.从map中获取线程存储的K-V Entry节点。
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                //3.从Entry节点获取存储的Value副本值返回。
                T result = (T)e.value;
                return result;
            }
        }
        //map为空的话,创建map,设置初始值,如果我们没有重写initialValue方法的话,初始值默认为null
        //所以我们在使用中要判断是否为空指针NullPointerException。
        return setInitialValue();
    }
    /**
     * 将当前的ThreadLocal和value一起存放到当前线程的threadlocals字段中(ThreadLocalMap中)
     */
    public void set(T value) {
        Thread t = Thread.currentThread();
        //1.获取当前线程的成员变量map
        ThreadLocalMap map = getMap(t);
        //2.map非空,则重新将ThreadLocal和新的value副本放入到map中。
        if (map != null)
            map.set(this, value);
        //3.map空,则对线程的成员变量ThreadLocalMap进行初始化创建,并将ThreadLocal和value副本放入map中。
        else
            createMap(t, value);
    }
  • 当我们定义一个ThreadLocal变量
  • 当我们获取或者修改它的值的时候,它会先获取当前线程
  • 每个线程又存放着一个ThreadLocalMap,该ThreadMap中存放着键值对,键就是ThreadLocal,值就是我们的变量值
  • 我们就获取每个线程中的ThreadLocalMap
  • 然后我们根据当前调用的ThreadLocal就可以找到该值,或者存储当前ThreadLocal和值组成的键值对
  • 总结:
    • 可以发现,实际上是通过ThreadLocal对象本身和当前线程结合在一起唯一确定了该值

上述用到的getMap和createMap方法如下:

    /**
     * 返回Thread类的threadLocals字段
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    /**
     * 使用ThreadLocalMap的第一个构造方法创建一个ThreadLocalMap赋给Thread类中的threadLocals的引用
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

(2)setInitialValue()

    /**
     * 设置当前ThreadLocal对象的初始值
     */
    private T setInitialValue() {
        //通过initialValue获取值
        T value = initialValue();
        //将当前ThreadLocal对象和value加入到当前线程的threadlocals中去
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        //返回初始值
        return value;
    }

而initialValue()方法默认是返回null,我们可以在定义ThreadLocal的时候重写该方法,设置返回值

(3)remove()

    /**
     * 在当前线程的threadlocals中删除当前ThreadLocal对象对应的键值对
     */
    public void remove() {
        ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null)
            m.remove(this);
    }

(4)threadLocalHashCode

我们在前面的ThreadLocalMap中大量的使用了threadLocalHashCode来计算数组的索引

将threadLocalHashCode进行一个位运算得到数组索引

threadLocalHashCode代码如下:

Java多线程(11)——ThreadLocal源码剖析_第8张图片

  • 由于nextHashCode是static的,从属于类,所以在一个程序中,它只在类的初始化阶段调用new AtormicInteger()初始化为0,而后每一次新建ThreadLocal对象的时候,由于threadLocalHashCode是非static的,所以每减一次对象,都调用一次nextHashCode()方法自增一次,增量为0x61c88647。

0x61c88647是斐波那契散列乘数,它的优点是通过它散列(hash)出来的结果分布会比较均匀,可以很大程度上避免hash冲突,已初始容量16为例,hash并与15位运算计算数组下标结果如下:

hashCode 数组下标
0x61c88647 7
0xc3910c8e 14
0x255992d5 5
0x8722191c 12
0xe8ea9f63 3
0x4ab325aa 10
0xac7babf1 1
0xe443238 8
0x700cb87f 15

 

 

 

 

 

 

 

 

 

 

总结:结合最开始的图看

  • 1.对于同一个ThreadLocal来讲,他的索引值i是确定的,因为都是通过相同的上述方法和值计算出来的,在不同线程之间访问时访问的是不同的table数组的同一位置即都为table[i],只不过这个不同线程之间的table是独立的。
  • 2.对于同一线程的不同ThreadLocal来讲,这些ThreadLocal实例共享一个table数组,然后每个ThreadLocal实例在table中的索引i是不同的。

3.ThreadLocal不支持继承性

3.1 ThreadLocal不支持继承性的演示

ThreadLocal案例:

package ThreadLocal;

public class TestThreadLocal {


    public static ThreadLocal threadLocal = new ThreadLocal();

    public static void main(String[] args) {
        threadLocal.set("hello world");

        new Thread(()->{
            System.out.println(Thread.currentThread().getName()+":"+threadLocal.get());
        },"A").start();

        System.out.println("main:"+threadLocal.get());

    }

}

Java多线程(11)——ThreadLocal源码剖析_第9张图片

  • 上述代码的逻辑
    • 先在类中定义了一个ThreadLocal字段
    • 然后在主线程中对该ThreadLocal使用set进行赋值
    • 然后线程A和线程main中分别使用get方法去访问它
  • 从源码角度分析上述的执行过程
    • 在main线程中调用set方法,而在set方法中会找到main线程的threadLocals,然后将我们定义的threadlocal和“hello world”键值对,存放到main线程的threadLocals
    • 然后当我们在线程A中调用get的时候,它是在线程A自己的threadLocals中查找threadLocal对应的值,而在main线程中调用get的时候,它是在main线程中自己的threadLocals中查找threadLocal对应的值,即它们是在不同的ThreadMap中查找
    • 而set只设置了main线程的值,线程A中的并未进行任何设置,所以获取到的是默认值null

所以ThreadLocal本身并不能让子线程访问到父线程中的ThreadLocal变量,这也即ThreadLocal的不支持继承性

3.2 支持继承的InheritableThreadLocal

为了解决上述问题,就有了InheritableThreadLocal

  • InheritableThreadLocal类继承自ThreadLocal,它提供了一个特性,就是让子线程可以访问在父线程中设置的本地变量

整个InheritableThreadLocal的实现如下:

public class InheritableThreadLocal extends ThreadLocal {
    /**
     * 重写childValue方法,该方法直接返回,父线程传入的值
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * 重写父线程中getMap方法
     * 
     * 它返回的是Thread类中的inheritableThreadLocals的引用
     */
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    /**
     * 使用ThreadLocalMap的第一个构造方法创建一个ThreadLocalMap赋给Thread类中的inheritableThreadLocals的引用
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
  • 可以看到该类只是重写了ThreadLocal的三个方法,我们前面提到过,Thread类拥有两个ThreadLocalMap类型的字段
  • Java多线程(11)——ThreadLocal源码剖析_第10张图片
  • 由于上述三个方法的重写,InheritableThreadLocal第一次set调用createMap创建的是inheritableThreadLocals而不再是threadLocals,当调用get方法获取当前线程内部的map变量时,获取的是inheritableThreadLocals而不再是threadLocals

综上:在InheritableThreadLocal的世界里,变量inheritableThreadLocals替代了threadLocals

使用InheritableThreadLocal改进上述案例:

package ThreadLocal;

public class TestInheritableThreadLocal {


    public static ThreadLocal inheritableThreadLocal = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        inheritableThreadLocal .set("hello world");

        new Thread(()->{
            System.out.println(Thread.currentThread().getName()+":"+inheritableThreadLocal .get());
        },"A").start();

        System.out.println("main:"+inheritableThreadLocal .get());

    }

}

Java多线程(11)——ThreadLocal源码剖析_第11张图片

  • 很容易发现,此代码只是把上述代码中创建的ThreadLocal对象替换成了创建InheritableThreadLocal对象

分析此代码的执行过程:

  • 我们先要看一下Thread类的构造方法
  • Java多线程(11)——ThreadLocal源码剖析_第12张图片
    • 可以看到在创建Thread的时候,都会进行调用init方法,而init方法中只要父线程的inheritableThreadLocals字段不为空,就会有通过ThreadLocalMap中的第二个构造方法(上述在介绍ThreadLocalMap的时候有分析过的)来创建一个ThreadLocalMap赋值给新建的Thread(即子线程)的inheritableThreadLocals
  • 我们分析代码中的执行流程:
    • 当我在main线程中调用set方法时,set方法中会将键值对存放到main线程的inheritableThreadLocals字段
    • 然后我们创建线程A的时候,由于父线程main(我们在main线程中调用它的构造方法创建,所以上述的获取到的就是main线程,即父线程就是main线程)
    • 所以此时判断满足
    • 然后就将父线程的通过ThreadLocaMap复制创建新的ThreadLocalMap赋值给正在创建的线程A的inheritableThreadLocals字段,所以这样线程A也就有了相同的inheritableThreadLocals
    • 然后分别在main线程和线程A中调用get均是获取他们各自的inheritableThreadLocals字段,它们拥有相同的值
  • 总结:在子线程中能够访问父线程的InheritableThreadLocal变量是因为使用了父线程inheritableThreadLocals通过复制创建了新的它自己的ThreadLocalMap赋值给了它自己的inheritableThreadLocals

你可能感兴趣的:(#,Java多线程)