ThreadLocal 源码分析

一、概述

是不是觉得它是一个线程?不要被名字迷惑,它并不是一个线程。

在《从源码理解Android Handler消息机制》一文中,我们提到ThreadLocal,当时我们这么解释:ThreadLocal 你可以理解为保存一个在线程范围内可见的变量。那么ThreadLocal是如何做到的呢?Follow Me ,看看源码如何实现的。

二、源码分析

平常我们使用ThreadLocal都是调用其set()和get()方法,基于这两个方法为切入点我们来分析下它的实现原理。

老规矩,源码是最好的解释,直接上源码:

代码 1.1
    public void set(T value) {
        Thread t = Thread.currentThread();//获取当前调用的线程
        ThreadLocalMap map = getMap(t);//往下面看
        if (map != null)
            map.set(this, value);//直接往map添加数据 查看代码1.3
        else
            createMap(t, value);//查看代码1.2
    }

    ThreadLocalMap getMap(Thread t) {
         //直接返回线程的一个变量 我们发现是 ThreadLocal.ThreadLocalMap threadLocals = null;
        return t.threadLocals;
    }

static class ThreadLocalMap {}//名字叫Map 并没有实现Map接口

上边的set()方法里主要内容:

  1. 获取线程的ThreadLocalMap threadLocals 对象;
  2. 根据threadLocals 是否为空来决定是创建ThreadLocalMap 还是往ThreadLocalMap 添加对象;

下边看下createMap方法:

代码 1.2
 void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);//生成ThreadLocalMap
    }

  ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {//ThreadLocal为key
         table = new Entry[INITIAL_CAPACITY];
        //注意ThreadLocal的threadLocalHashCode 
         int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
         table[i] = new Entry(firstKey, firstValue);
         size = 1;
         setThreshold(INITIAL_CAPACITY);
    }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;//必须是二的幂

createMap()方法里主要做了:

  1. 生成ThreadLocalMap实例;
  2. 用ThreadLocal作为key,然后生成一个节点放入数组,至于数组位置,则由ThreadLocal的threadLocalHashCode&(INITIAL_CAPACITY -1)决定;
  3. INITIAL_CAPACITY 这个值必须是2的幂,初始为16;
  4. 神奇的 0x61c88647 ,每当我们new一个ThreadLocal对象,新对象的threadLocalHashCode值等于在静态变量nextHashCode变量上加 0x61c88647,至于原因看下边的数据测试:
public class ThreadLocal {
 
    private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode = new AtomicInteger();//原子变量  通过CAS操作更新

    private static final int HASH_INCREMENT = 0x61c88647;

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    //我们看下 0x61c88647如何神奇:
  public static void main(String[] args) {
       int hashCode = 0x61c88647;
        
       System.out.println("数组length 为 16 ");
       for(int i =0;i<16;i++){
           System.out.print((15&(i*hashCode))+"  ");
       }
       
        System.out.println("");
        System.out.println("数组length 为 32 ");
       
        for(int i =0;i<32;i++){
            System.out.print((31&(i*hashCode))+"  ");
        }
    }
运行结果:
数组length 为 16 
0  7  14  5  12  3  10  1  8  15  6  13  4  11  2  9  
数组length 为 32 
0  7  14  21  28  3  10  17  24  31  6  13  20  27  2  9  16  23  30  5  12  19  26  1  8  15  22  29  4  11  18  25  

结果很神奇,这个跟数学相关,我也不是很清楚为什么,总之运行结果是散列的分散在数组中。

接下来我们看下ThreadLocalMap 的 set()方法:

代码1.3

        private void set(ThreadLocal key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);//有没有很熟悉

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {//线性探测法
                ThreadLocal k = e.get();

                if (k == key) {//key相同 替换值
                    e.value = value;
                    return;
                }
                 //Entry 集成自 WeakReference  k很有可能为null
                if (k == null) {
                    replaceStaleEntry(key, value, i);//查看代码1.4
                    return;
                }
            }

            tab[i] = new Entry(key, value);//表里没数据  生成节点加入进去
            int sz = ++size;//更改当前size
            if (!cleanSomeSlots(i, sz) && sz >= threshold)//判断是否触发阈值 触发则扩容
                rehash();//查看代码1.7
        }

主要用线性探测法向数组中确定节点位置,与HashMap的链地址法实现方式不一样。

代码1.4
   //replaceStaleEntry方法  
        private void replaceStaleEntry(ThreadLocal key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            int slotToExpunge = staleSlot;//记录删除节点位置起始位置 
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))//从数组往前找 有节点但节点无key值则更新slotToExpunge ,否则停止查找
                if (e.get() == null)
                    slotToExpunge = i;

            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {//线性探索查找key相同节点
                ThreadLocal k = e.get();

                if (k == key) {//如果 k == key 则更新value 讲该节点更新到 staleSlot位置上 
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                   //清除部分节点expungeStaleEntry()查看代码1.5 cleanSomeSlots()查看代码1.6
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);//未找到 生成新节点放入

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
代码1.5 方法expungeStaleEntry()
       // 从删除节点到后边遍历 到第一个为 null节点之间的节点都经过检测 返回第一个null节点位置
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;//删除该位置节点
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {//从当前节点现象探索遍历
                ThreadLocal k = e.get();
                if (k == null) {//删除 key为null节点 
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);//根据数组长度计算出新的位置 因为数组有可能扩容
                    if (h != i) {//不相等的情况需要改变该节点位置
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)//讲h位置节点进行线性探测法确定位置
                            h = nextIndex(h, len);
                        tab[h] = e;//讲e节点更新到h位置
                    }
                }
            }
            return i;
        }
代码 1.6 方法cleanSomeSlots() 用于清除线性探索方向上的空节点
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);// 见 2.2.2 
                }
            } while ( (n >>>= 1) != 0);//n = n>>>1 无符号右移动并赋值 这边每次除以2有点不太理解 欢迎大家讨论
            return removed;
        }
代码 1.7
        private void rehash() {
            expungeStaleEntries(); //见下边

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();//见1.8
        }

        /**
         * Expunge all stale entries in the table.
         */
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);//见1.5
            }
        }
    }
代码1.8 resize()
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;//扩容 容量依然是2的幂
            Entry[] newTab = new Entry[newLen];//生成新的数组
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {//遍历旧的数组
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);//计算在新数组中的位置
                        while (newTab[h] != null)// 冲突后 线性探测法
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
            //更新属性
            setThreshold(newLen);
            size = count;
            table = newTab;
        }

上边就是ThreadLocal中set()方法的实现,主要: 向数组中插入节点,根据key (ThreadLocal)的threadLocalHashCode&(len-1)决定位置,然后根据线性探索法解决冲突问题,包括如果数组size超过阈值则扩容。

下边分析下get()方法:

代码2.1
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);//查看2.2
            if (e != null)
                return (T)e.value;
        }
        return setInitialValue();//这是一个空方法,如果未命中则调用用该方法返回的默认value
    }
代码2.2
        private Entry getEntry(ThreadLocal key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;//命中
            else
                return getEntryAfterMiss(key, i, e);//未命中 见下方
        }

        private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {//线性探索查找
                ThreadLocal k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;//未找到返回null
        }

通过get()我们可以看出:

  1. 根据key (ThreadLocal)的threadLocalHashCode&(len-1)位置的值是否命中,命中返回,没有命中则根据线性探索法查找节点;
  2. 第一步没找到则调用setInitialValue()方法返回值来充当返回值,该方法用户可以重写;

下边看下remove()方法

代码3.1
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);// 见下方
     }

        private void remove(ThreadLocal key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);//熟不熟悉 
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {//线性探索法查找
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);//见1.5
                    return;
                }
            }
        }

三、总结

上边就是为大家分析的ThreadLocal的实现,主要实现依靠:

  1. 每个线程保留一个ThreadLocalMap 变量;
  2. 当我们向ThreadLocal中放入值的时候,其实我们是将值放入到了Thread的threadLocals中;
  3. 没当我们实例一个ThreadLocal的时候,该实例的threadLocalHashCode值会改变,ThreadLocalMap中的table数组长度记为len,则不同实例的threadLocalHashCode&(len-1)会散列在table数组的不同位置;
  4. ThreadLocalMap中table属性中的Entry继承自WeakReference,所以key很容易被回收;
  5. 当出现hash冲突时,是使用线性探索法查找,不同于HashMap的查找原理;

以上就是为大家分享的ThreadLocal源码分析。感谢你的耐心阅读,如有错误,欢迎指正。如果本文对你有帮助,记得点赞。欢迎关注我的微信公众号:


ThreadLocal 源码分析_第1张图片
qrcode_for_gh_84a02a29fedd_430.jpg

你可能感兴趣的:(ThreadLocal 源码分析)