[Java源码][并发J.U.C]---解析ThreadLocal

前言

本文将以一个例子开头简单看看ThreadLocal类的特性,进而分析该类的源代码.

本文源码下载

例子

启动三个线程,每个线程的操作都是使用静态变量count把原先的值加1.

package com.com.example.threadlocal;

import java.util.concurrent.TimeUnit;
public class TestThreadLocal {

    static ThreadLocal count = new ThreadLocal(){
        protected Integer initialValue() {
            return 100;
        }
    };

    public static void main(String[] args) throws InterruptedException {
        for (int i = 0; i < 3; i++) {
            new Thread(new Runner(), "thread-" + i).start();
            TimeUnit.SECONDS.sleep(1);
        }
    }

    static class Runner implements Runnable {
        public void run() {
            for (int i = 0; i < 3; i++) {
                count.set(count.get() + 1);
                System.out.println(Thread.currentThread().getName() + ":" +
                        count.get());
            }
        }
    }
}

结果如下: 可以看到每个线程都有单独的一个count实例一样,这个就是threadlocal的特性可以使得线程之间隔离,相当于每个线程自己保存了一份自己的数据副本,在本线程中操作只会改变当前线程的值并不会影响其他线程的值.

thread-0:101
thread-0:102
thread-0:103
thread-1:101
thread-1:102
thread-1:103
thread-2:101
thread-2:102
thread-2:103

类图

下面的图是整个ThreadLocal涉及到的所有类. 接下来通过该图理解一下整体的操作.

[Java源码][并发J.U.C]---解析ThreadLocal_第1张图片
threadlocal.png

实现思路: 每个线程实体类中都保存着一个ThreadLocal.ThreadLocalMap用于存放该线程中所有的映射关系, 这个映射关系是由threadlocal类和初始化的value对应并放在ThreadLocalMap.Entry类中存放.
对上面的例子而言: thread-1.threadlocals存放了[count, 100]的映射关系, thread-2thread-3 各自的threadlocals都存放着[count, 100]的映射关系. 所以当thread-1运行run方法时循环了三次操作后thread-1.threadlocals存放了[count,103],而hread-2thread-3则保存不变.

1. 通过类ThreadLocal中的构造方法threadlocal()生成对象.
2. 通过initialValue初始化value值.
3. ThreadLocal类的set,get,remove操作都是调用的ThreadLocalMap的方法进行操作, 因为ThreadLocalMap定义了这些逻辑的核心实现, 那ThreadLocal类的方法做了什么呢? 主要是为了获取当前线程的ThreadLocalMap对象, 即当前线程的成员变量threadlocals, 通过该变量进行真正的逻辑操作.

所以接下来我们将简单看看ThreadLocal的方法, 重点分析的是ThreadLocalMap类.

ThreadLocal中的set, get, remove

set 方法的逻辑

// 获取线程t的ThreadLocalMap对象
ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
}
// 根据当前的ThreadLocal对象和firstValue为线程t的成员变量threadlocas生成一个ThreadLocalMap对象
void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
}
/**
  *  1. 获取当前线程
  *  2. 如果当前线程的ThreadLocalMap对象threadlocals已经存在,则直接调用ThreadLocalMap类的set方法
  *  3. 如果不存在,则创建一个ThreadLocalMap对象
  */
public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
}

get 方法的逻辑

// 设置初始化值并返回初始值
private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
}
// 取当前线程的ThreadLocalMap对象
// 如果不存在或者不存在当前ThreadLocal对象不在ThreadLocalMap中,调用setInitialValue()返回初始值
// 反之则返回当前线程的ThreadLocalMap对象中当前ThreadLocal对象对应的value值
public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
}

remove方法的逻辑

// 获取当前线程的ThreadLocalMap对象,如果不为空,则当前线程的ThreadLocalMap对象
//中的当前ThreadLocal对象所对应的节点.
public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
}

ThreaLocalMap中的方法

ThreadLocalMap是整个ThreadLocal的核心部分. 由于我把ThreadLocalMap的源码单独拿了出来(源码下载),接下来先由一个小例子简单测试一下.

在测试之前需要先看一下ThreadLocalMap中的Entry类继承了WeakReference类,请注意Entrykey也就是ThreadLocal对象是采用弱引用的方法,而value还是一个强引用. 关于弱引用可以关注我的另一个博客通过例子理解java强引用,软引用,弱引用,虚引用

static class Entry extends WeakReference> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
        }

下面的例子是生成了9(10的时候数组会扩展)个ThreadLocal对象并且全部都set到了ThreadLocalMap对象中,然后打印一下整个数组中的情况. 因为每个tls[i]对应的对象现在都有两个引用,1个强引用tls[i]和一个弱引用在某个entry[h]里面.所以我做了个简单测试把tls[4]这个强引用去除,然后主动调用gc后再次打印观察数组情况.

public static void test_2() {
        ThreadLocal [] tls = new ThreadLocal[9];
        for (int i = 0; i < tls.length; i++) {
            tls[i] = new ThreadLocal();
        }
        ThreadLocalMap map = new ThreadLocalMap(tls[0], 0);
        for (int i = 1; i < 9; i++) {
            System.out.print("i = " + i + ", hash = ");
            map.set(tls[i], i);
        }
        map.printEntry();
        tls[4] = null;
        System.gc();
        //map.set(tls[4], 4);
        System.out.println("---------------------------------");
        map.printEntry();
    }

输出: 被垃圾回收器回收了

i = 0, hash = 0 & (16 - 1) = 0
i = 1, hash = 1640531527 & (16 - 1) = 7
i = 2, hash = -1013904242 & (16 - 1) = 14
i = 3, hash = 626627285 & (16 - 1) = 5
i = 4, hash = -2027808484 & (16 - 1) = 12
i = 5, hash = -387276957 & (16 - 1) = 3
i = 6, hash = 1253254570 & (16 - 1) = 10
i = 7, hash = -1401181199 & (16 - 1) = 1
i = 8, hash = 239350328 & (16 - 1) = 8
table[0] = [com.sourcecode.threadlocal.ThreadLocal@6f94fa3e,0]
table[1] = [com.sourcecode.threadlocal.ThreadLocal@5e481248,7]
table[2] = null
table[3] = [com.sourcecode.threadlocal.ThreadLocal@66d3c617,5]
table[4] = null
table[5] = [com.sourcecode.threadlocal.ThreadLocal@63947c6b,3]
table[6] = null
table[7] = [com.sourcecode.threadlocal.ThreadLocal@2b193f2d,1]
table[8] = [com.sourcecode.threadlocal.ThreadLocal@355da254,8]
table[9] = null
table[10] = [com.sourcecode.threadlocal.ThreadLocal@4dc63996,6]
table[11] = null
table[12] = [com.sourcecode.threadlocal.ThreadLocal@d716361,4]
table[13] = null
table[14] = [com.sourcecode.threadlocal.ThreadLocal@6ff3c5b5,2]
table[15] = null
---------------------------------
table[0] = [com.sourcecode.threadlocal.ThreadLocal@6f94fa3e,0]
table[1] = [com.sourcecode.threadlocal.ThreadLocal@5e481248,7]
table[2] = null
table[3] = [com.sourcecode.threadlocal.ThreadLocal@66d3c617,5]
table[4] = null
table[5] = [com.sourcecode.threadlocal.ThreadLocal@63947c6b,3]
table[6] = null
table[7] = [com.sourcecode.threadlocal.ThreadLocal@2b193f2d,1]
table[8] = [com.sourcecode.threadlocal.ThreadLocal@355da254,8]
table[9] = null
table[10] = [com.sourcecode.threadlocal.ThreadLocal@4dc63996,6]
table[11] = null
table[12] = [null,4]
table[13] = null
table[14] = [com.sourcecode.threadlocal.ThreadLocal@6ff3c5b5,2]
table[15] = null

ThreadLocalMap 与 HashMap处理冲突不一样, HashMap采用的是拉链法,而ThreadLocal采用的开放地址法, 每个ThreadLocal对象都有一个threadLocalHashCode 通过 nextHashCode() 每生成一个ThreadLocal对象都在前面对象的threadLocalHashCode基础上加一个常量HASH_INCREMENT = 0x61c88647, (从上面的例子中也可以看出来.)至于为什么?应该是hash冲突的比较少,具体为什么我也不太清楚.

插入或者更新操作 set(ThreadLocal key, Object value)

/**
         * 作用: 将key和value 插入(如果key不存在)或者更新(如果key存在)
         * @param key    键
         * @param value  值
         */
        private void set(ThreadLocal key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            //System.out.format("%d & (%d - 1) = %d\n", key.threadLocalHashCode, len, i);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal k = e.get();
                // 如果key存在,则替换该值
                if (k == key) {
                    e.value = value;
                    return;
                }
                /**
                 * 如果当前k过期,则调用replaceStaleEntry方法
                 * 无论key是否存在,都会保存在位置i,具体细节可以看replaceStaleEntry的注释
                 */
                
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            /**
             *  有限次去查找过期节点并删除过期节点,如果有删除则返回
             *  如果没有删除则判断是否超过阀值
             *  如果超过阀值则调用rehash函数.
             */
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

作用: 将key和value 插入(如果key不存在)或者更新(如果key存在)

对应流程图如下
[Java源码][并发J.U.C]---解析ThreadLocal_第2张图片
set.png

expungeStaleEntry(int staleSlot)

/**
         *
         * 作用: 从该索引staleSlot往下直到遇到null结束返回当前下标,遇到的过期元素tab[i]设置为null,遇到的正常节点做rehash.
         * @param staleSlot 需要清理的位置, 一个已经确定过期的位置
         * @return 返回从staleSlot位置开始第一个为entry值为null的位置
         */
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            /**
             *  清除该staleSlot的值
             */
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            /**
             *  从stateSlot 开始往下继续搜索
             *  1. 如果为null, 直接退出
             *  2. 如果虚引用对应的key已经为null,也就是被垃圾回收器回收了,则清除该位置
             *  3. 如果不是1或者2,表明该位置存着一个正常值,观察是否需要rehash,因为取值的时候会方便
             *     因为该类处理hash冲突使用的是:开放定址法
             */
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    /**
                     *  因为处理冲突使用的开放地址法, 现在已经删除了一个位置,
                     *  并且该节点前面的节点有可能为null,因为k==null的时候会把tab[i]=null,
                     *  所以比如下次set操作对该key进行操作的时候就找不到该key,因为前面有null值,
                     *  会认为该key不存在,重新创建一个新的节点,因此会造成有两个节点拥有同一个key.
                     *
                     *  所以需要进行rehash
                     *
                     *  因此之前有些位置因为冲突没有存放到对应的hash值该有的位置,
                     *  所以下面的方法就是检查并且把此对象存到对应的hash值的位置或者它的后面.
                     */
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // 往下继续寻找,值到找到为null的空位置,然后把只放进去
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

作用: 从该索引往下直到遇到null结束返回当前下标,遇到的过期元素tab[i]设置为null,遇到的正常节点做rehash.

为什么需要做rehash?

因为处理冲突使用的开放地址法, 现在已经删除了一个位置, 并且该节点前面的节点有可能为null,因为k==null的时候(表明该节点已经过期)会把tab[i]=null, 所以比如下次set操作对该key进行操作的时候就找不到该key,因为前面有null值, 会认为该key不存在,然后重新创建一个新的节点,因此会造成有两个节点拥有同一个key. 所以需要进行rehash.

对应流程图如下
[Java源码][并发J.U.C]---解析ThreadLocal_第3张图片
expungeStaleEntry.png

cleanSomeSlots

/**
         *
         * @param i 从该位置i的下一个位置开始
         * @param n n >>>= 1决定尝试的次数
         * @return 返回是否有清除过陈旧的值
         */
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                // 获取下一个位置
                i = nextIndex(i, len);
                Entry e = tab[i];
                // 如果当前节点不为null,并且对应的key已经被垃圾回收器收集
                if (e != null && e.get() == null) {
                    // 重新设置n 和 设置removed标志位为true
                    n = len;
                    removed = true;
                    // 清除陈旧的位置节点i, 并设置i为当前i下一个位置开始第一个为entry值为null的位置
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

作用:
1. 尽可能多的删除过期的节点.
2. 检查次数由n决定. 为logn或者n.
3. 返回是否有删除过期元素

对应流程图如下
[Java源码][并发J.U.C]---解析ThreadLocal_第4张图片
cleanSomeSlots.png

replaceStaleEntry

/**
         *
         * 将set操作期间遇到的过期节点替换为指定键的节点。
         * 无论指定键的节点是否已存在,value参数中传递的值都存储在节点中。
         * 作为副作用,此方法将清除包含过期节点的“run”中的所有过期节点。 (run是两个空槽之间的一系列节点。)
         *
         *
         * @param key         节点的键
         * @param value       节点的值
         * @param staleSlot   在寻找key过程中遇到的第一个过期的节点
         */
        private void replaceStaleEntry(ThreadLocal key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            /**
             * 备份以检查当前"run"中的先前失效节点。
             * 我们一次清理整个"run",以避免由于垃圾收集器释放串联的refs(即,每当收集器运行时)不断的增量重复。
             * slotToExpung 始终代表着整个run里面的第一个过期节点.
             */
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            /**
             *   寻找"run"中的key 或者第一个空节点(null)
             *   1. 找到key的位置i,就交换tab[i]和tab[staleSlot],提高查找时候的命中率
             *   2. 如果找到一个空节点,就表示该key之前没有插入到该tab中过,跳出循环后创建一个新的节点(key,value)
              */

            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();

                /**
                 * 如果我们找到键,那么我们需要将它与陈旧条目交换以维护哈希表顺序。
                 * 新陈旧的插槽或任何其他过期的插槽,在它上面遇到,然后可以发送到expungeStaleEntry
                 * 删除或重新运行run中的所有其他条目。
                 *
                 */
                if (k == key) {
                    e.value = value; // 替换value

                    tab[i] = tab[staleSlot];  // 交换
                    tab[staleSlot] = e;

                    /**
                     * 如果slotToExpunge == staleSlot
                     * 表明当前的i是整个run里面的第一个过期的元素节点,更新一下slotToExpunge即可.
                     */
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                /**
                 *  如果在反向扫描中找不到过期的节点, 那么在扫描key是看到的
                 *  第一个过期节点就是整个run里面的过期节点
                 */
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            /**
             *  如果key没有找到,表明该key是第一次存入到该table中,
             *  则生成一个新的节点并放到staleSlot的位置.
             */
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            /**
             * 如果staleSlot不是该run里面的唯一一个过期节点,
             * 则都需要进行清除工作
             */
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

作用: 替代原有key的节点或者新生成节点后做清除过期节点操作.
1. 定义了slotToExpunge为整个run(run是两个空槽之间的一系列节点.)里面的第一个过期节点.
2. 查找该key是否之前有存入过. 如果存在则存入到位置staleSlot并清除从slotToExpunge开始该run里面的过期元素.
3. 如果不存在该key则创建一个新节点并放到位置staleSlot,如果staleSlot是整个run里面的唯一一个过期节点,则不需要清除,否则需要清除从slotToExpunge开始该run里面的过期元素.

详细操作可以看代码注释和下面的流程图.

对应流程图如下
[Java源码][并发J.U.C]---解析ThreadLocal_第5张图片
replaceStaleEntry.png

rehash, resize, expungeStaleEntries

/**
         * 作用:
         * 1. 先对整个数组的过期节点进行清除
         * 2. 判断是否需要对数组进行扩展
         */
        private void rehash() {
            /**
             * 先对整个数组的过期节点进行清除
             */
            expungeStaleEntries();

            /**
             *  size >= 0.75 * threshold 则扩大容量
             */

            if (size >= threshold - threshold / 4)
                resize();
        }

        /**
         *  作用: 扩展数组
         *  size扩大两倍, 每一个正常的元素做rehash映射到新的数组中
         *  每一个过期的元素的value都设置为null方便gc
         */
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }

        /**
         *  作用: 从头到尾扫描整个数组对所有过期节点做清理工作
         */
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }

取值操作 getEntry, getEntryAfterMiss

/**
         * 作用:  先用hash定位寻找key,如果找到key 返回该节点
         *       如果没有找到key 返回getEntryAfterMiss(key, i, e)的结果
         */
        private Entry getEntry(ThreadLocal key) {
            int i = key.threadLocalHashCode & (table.length - 1); //计算hash值
            Entry e = table[i];
            if (e != null && e.get() == key) // 如果命中
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

        /**
         * 作用:  利用开发地址法寻找key,如果找到key 返回该节点
         *       如果没有找到key 返回null
         */
        private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal k = e.get();
                if (k == key) //找到key存在的位置,直接返回节点
                    return e;
                /**
                 *  如果当前节点的key过期,则调用expungeStaleEntry(i)进行清理当前位置
                 *  并且不接受返回值,i 没有发生变化
                 *
                 *  如果不过期则取下一个节点
                 */
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

作用: 根据key获得对应的节点,如果不存在则返回null.

对应流程图如下
[Java源码][并发J.U.C]---解析ThreadLocal_第6张图片
getEntry.png

删除操作 remove

        /**
         *
         * 作用: 删除key,调用了expungeStaleEntry(i)做清除和rehash工作,
         *
         * @param key 要删除的键值
         */
        private void remove(ThreadLocal key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1); //获取hash值, 如果不在该位置则继续往下找直到遇到null
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i); //做清除工作
                    return;
                }
            }
        }

作用: 删除该key所对应的节点.

为什么需要调用expungeStaleEntry(i)?

这是因为在清除了位置i后, 整个run(run是两个空槽之间的一系列节点.)该位置会变为null,因此位置i后面的节点需要做rehash, 这是因为该数组处理hash冲突采用的是开放地址法,因此后面的节点的hash值有可能不在它本身所处的位置, 如果后面的某一个节点K本身的hash值在i前面(比如i-1,还是在整个run里面), 那么后续操作在对节点K更新或者获取操作时就会找不到节点K, 因为取hash值是i-1,检查到i时发现已经为null,所以会认为该节点K不存在. 所以可以看到整个ThreadLocalMap对过期元素做删除操作都是调用expungeStaleEntry(i)方法.

关于为什么使用弱引用和内存泄露的问题?

可以参考该文章: 深入分析 ThreadLocal 内存泄漏问题
和 深入分析 ThreadLocal 内存泄漏问题

你可能感兴趣的:([Java源码][并发J.U.C]---解析ThreadLocal)