Java并发: ThreadLoal与内存泄露

Java并发中提供了ThreadLocal类用于存放线程本地的对象, 顾名思义每个线程都会有一个独立的实例, 线程之间相互不会影响, 由此保证了线程安全. 除了学习这个类本身的线程安全机制以外, 这个的实现也是弱引用以及内存泄露处理的非常好的例子, 在面试中面试官在提问这个类的相关知识点时, 很可能也会想了解候选人对四种Java中的引用以及内存泄露的场景的掌握. 因此本文从ThreadLocal切入, 从源码的级别挖掘一下包含的知识点和指的借鉴的地方.

首先我们来看一个ThreadLocal的使用例子,

public class TestThreadLocal {
    public static void main(String[] args) throws InterruptedException {
        ThreadLocal<Boolean> a = new ThreadLocal<>();
        a.set(false);
        System.out.println(a.get());
        Thread t = new Thread(()->{
            a.set(true);
            System.out.println(a.get());
        });
        t.start();
        t.join();
        System.out.println(a.get());
    }
}

/*
main: false
Thread-0: true
main: false
*/

ThreadLocal是线程私有资源的一种抽象, 那么ThreadLocal内部是如何实现线程间相互不干扰的呢?

ThreadLocal的实例化与成员方法

ThreadLocal提供的api并不多, 包含了实例化一个泛型类的ThreadLocal和删改查这个对象的方法

  • ThreadLocal t = new ThreadLocal<>() 实例化一个Integer的ThreadLocal对象
  • get() 得到这个ThreadLocal的实际对象值
  • set(Integer value) 修改当前线程的这个ThreadLocal的值
  • remove() 删除当前线程的这个ThreadLocal

一个一个方法来看, ThreadLocal的构造器是一个空方法, 创建一个ThreadLocal实例, 之后是set方法:

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

这个方法虽然很短, 但是已经揭开了ThreadLocal保证线程私有的方法的面纱, 实现获得当前线程, 然后从当前线程中获得一个ThreadLocalMap类的实例, 如果这个map存在则之间在map中添加这个value, 否则创建一个map的实例, 并将value放到map中.

我们再去Thread类中确认一下

public
class Thread implements Runnable {
	// ...
	
    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;
    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
    //...
}

果然在每个Thread中都有一个叫做ThreadLocalMap的类的实例作为成员变量, 不同的Thread的map彼此独立, 都独立维护了自己的键值对, 一个ThreadLocal变量实际上作为key和value构成 键值对, 每一个线程的map中都存了一份. 另外, Thread的类定义中还有一个inheritableThreadLocals, 它是什么以后再说.

那么我们接下来想一想, get方法应该怎么实现, 对于每个线程, 都应该拿着自己线程的这个ThreadLocal对象作为键去自己线程的ThreadLocalMap中查询value获得值, 那么get方法是这样做的么, 没错就是这样.

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

然后我们需要完善一些细节, 首先在set的时候, map的构造是懒加载的模式, 如果map为null, 则实例化一个新的map, 这个实例化的过程由createMap完成, 由于是线程安全的, 所以并不需要进行同步等, 具体的ThreadLocalMap的结构我们会在下一节讨论. 然后是get时如果没有这个key, 则会将这个key按照初始化的值加入到map中, 内含了一个set(null)的过程.

    private T setInitialValue() {
        T value = initialValue(); // 默认返回的是null
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

ThreadLocalMap类

接下来讨论一下ThreadLocalMap这个类的实现, 这个类实际上就是一个HashMap, 但是它并没有采用util中的HashMap的类进行实例化, 而是自己实现了一套HashMap. 提到hash方法, 一定需要考虑的就是hash冲突的解决方案, 我们知道, 在HashMap中采用的是拉链法解决, 而这里实现者选择了向后线性探测的方法, 这种方法相比拉链法, 在冲突少, 负载低的时候查询速度更快, 但是随着负载提高, 冲突变多, 线性探测的方法会带来比较大的时间开销, 极端情况查找变成了O(n)的时间复杂度. 此外, 线性探测的方法在删除元素时, 也会比拉链法更复杂点, 这个我们之后会看到.

在进一步深入之前, 我们还是来看下这个类中的重要常量, 首先和HashMap类似的, 该类的初始数组大小也是16, 并且其增长也是按照2的倍数增长, 并没有数组缩小的方法. 按2的幂增长的一个重要好处是可以用size-1与元素的hash做与操作代替取余操作, 节省了计算.此外, 它的负载阈值设为2/3, 但是在某些特殊情况下会减小为1/2, 在后文会更加具体的分析.

在继续深入ThreadLocalMap的基本结构之前, 我们首先回顾一下java中的四种引用的知识点.

强引用, 软引用, 弱引用, 幻象引用

  • 所谓强引用(“Strong” Reference),就是我们最常见的普通对象引用,只要还有强引用指向一个对象,就能表明对象还“活着”,垃圾收集器不会碰这种对象。对于一个普通的对象,如果没有其他的引用关系,只要超过了引用的作用域或者显式地将相应(强)引用赋值为 null,就是可以被垃圾收集的了,当然具体回收时机还是要看垃圾收集策略。
  • 软引用(SoftReference),是一种相对强引用弱化一些的引用,可以让对象豁免一些垃圾收集,只有当 JVM 认为内存不足时,才会去试图回收软引用指向的对象。JVM 会确保在抛出 OutOfMemoryError 之前,清理软引用指向的对象。软引用通常用来实现内存敏感的缓存,如果还有空闲内存,就可以暂时保留缓存,当内存不足时清理掉,这样就保证了使用缓存的同时,不会耗尽内存。
  • 弱引用(WeakReference)并不能使对象豁免垃圾收集,仅仅是提供一种访问在弱引用状态下对象的途径。这就可以用来构建一种没有特定约束的关系,比如,维护一种非强制性的映射关系,如果试图获取时对象还在,就使用它,否则重现实例化。它同样是很多缓存实现的选择
  • 对于幻象引用,有时候也翻译成虚引用,你不能通过它访问对象。幻象引用仅仅是提供了一种确保对象被 finalize 以后,做某些事情的机制,比如,通常用来做所谓的 Post-Mortem 清理机制,我在专栏上一讲中介绍的 Java 平台自身 Cleaner 机制等,也有人利用幻象引用监控对象的创建和销毁。

当弱引用变量引用的对象没有其他强引用继续引用时, 该对象在下一次gc时就会被收回. ThreadLocalMap中就是用到了弱引用的这种特点.

hash表的基本结构

补充完四种引用类型, 我们接着分析ThreadLocalMap的内部结构. 我们前面说到, ThreadLocalMap是每个线程持有的一个哈希表, 用来保存ThreadLocal在本线程中的副本. 哈希表的本质就是一个数组, 在该类中, 数组的元素Entry有一些特别, 以下是Entry的定义

	static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

Entry是一个静态内部类, 继承了WeakReference, 这就是我们前面说的弱引用, WeakReferenceReference的子类, 从名字我们可以看出, 他们都是对引用这个概念的抽象, 在Reference中有一个成员变量private T referent; 这个变量指向new出来的ThreadLocal对象, 而Entry的私有成员value指向set的参数指向的对象, 构成map的一个键值对.

可能的内存泄露

上文中提到了, 只被弱引用引用的对象会在下一次gc中被清理, 在Entry中, ThreadLocal自身是一个弱引用, 而value却是一个强引用, 这个强引用在ThreadLocal.get()的方法中被发布出去, 这个发布可能会造成和ThreadLocal不一致的生命周期, 例如, 从主线程传进来的对象作为value, 其生命周期是与主进程一致, 而当前方法的ThreadLocal局部变量的生命周期与当前方法一致, 这种不一致就可能导致内存泄露. 例如下面这种情况.

public class TestThreadLocal {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Future<Boolean> res = pool.submit(new MyTask()); // 长生命周期对象
        Boolean result = res.get(); // pool的线程中的a对应的Entry的value被发布出来
        System.gc(); // pool持有的线程中没有强引用的ThreadLocal a会被gc收回, 但是value被发布出来了, 有强引用
        >(打断点) System.out.println(result);
        pool.shutdown();
    }
    static class MyTask implements Callable<Boolean> {
        @Override
        public Boolean call() throws Exception {
            ThreadLocal<Boolean>  a = new ThreadLocal<>(); // 离开callable后 a就没有强引用了
            a.set(true);
            return a.get();
        }
    }
}

Java并发: ThreadLoal与内存泄露_第1张图片
上面这张图中, referent也就是ThreadLocal引用的对象随着run()函数结束就变成垃圾了, 但是value此时并不能被回收, 因此就导致了内存泄露. 要处理这种内存泄露, 最保险的方法就是要求我们在使用ThreadLocal时一定要记得显式remove. 为了防止因为用户忘了remove导致的严重的内存泄露问题, ThreadLocal中进行了比较复杂的设计.

接下来会首先分析下哈希表的增删改查, 然后再专门分析下ThreadLocal是怎么解决内存泄露的.

哈希冲突的处理

在哈希表中, 按照key进行set的过程, 就是对key的hash进行取余, 找到对应的位置, 并放置. 但是实际情况下, 正如我们前面说的, 可能发生哈希冲突, 这里是通过线性探测的方式在循环数组上找到第一个可以插入的位置进行解决的.

private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1); // 计算出位置

            for (Entry e = tab[i];
                 e != null; // 有hash冲突
                 e = tab[i = nextIndex(i, len)]) {  // 循环的向后查找下一个位置
                ThreadLocal<?> k = e.get();

                if (k == key) { // 如果找到了, 之前就存在, 直接替换value, 返回
                    e.value = value;
                    return;
                }

                if (k == null) { 
                // 如果这个Entry是存在的, 但是key却已经是null了, 说明发生了内存泄露  stale(腐败的) 
                // 在遇到第一个腐败的entry的时候, 进入replaceStaleEntry, 清除i所在的
                // 这个run(前后两个null包围的这段entry)中的腐败的节点, 并且将插入
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
			// 到这里说明
			// 第一, key在map中不存在, 
			// 第二i是map中原定位置之后第一个为null的位置
            tab[i] = new Entry(key, value);
            int sz = ++size;
            
            if (!cleanSomeSlots(i, sz) && sz >= threshold) 
            // 从i往后检查log(sz)个, 如果都没有腐烂并且负载依然超过了阈值
                rehash(); // 首先遍历整个数组清理一遍stale的entry, 如果负载还是超过的1/2, 则扩容
        }

按照key进行get的过程, 就是对key的hash取余看对应位置是否有要找的元素, 可能遇到的问题是key不存在.

 private Entry getEntry(ThreadLocal<?> key) {
     int i = key.threadLocalHashCode & (table.length - 1);
     Entry e = table[i];
     if (e != null && e.get() == key)
         return e;
     else
     // 如果在当前位置没有找到对应的key, 有两种可能, 不存在或者因为哈希冲突需要继续往后找
         return getEntryAfterMiss(key, i, e); 
 }
 
 private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
     Entry[] tab = table;
     int len = tab.length;

     while (e != null) {
         ThreadLocal<?> k = e.get();
         if (k == key)
             return e;
         if (k == null) // 沿路找到entry不为null但是key为null的腐烂的元素, 顺便清理掉
             expungeStaleEntry(i);
         else
             i = nextIndex(i, len); // 循环往后
         e = tab[i];
     }
     // 找到第一个null还没找到, 说明key不存在
     return null;
 }

还有删的情况, 也是类似, 线性探测的找到entry后, 删掉value和entry, 并顺便清理从这里开始往后到第一个null位置的所有腐烂节点

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) { // 找到了
            e.clear(); // 将当前的referent设为null
            expungeStaleEntry(i); // 当前位置现在变成了一个stale的entry, 清除这个腐败的entry
            return;
        }
    }
    // 没有找到key, 不做任何事
}

上面我们已经见过几次对腐败entry的处理了, 接下来会对这几个方法进

过期key的处理

首先是在set方法中遇到的replaceStaleEntry方法. 前面说到, 当set的线性探测过程遇到第一个腐败的entry的时候, 会进入replaceStaleEntry, 清除i所在这个run(前后两个null包围的这段entry)中的腐败的节点, 并且将插入. 这个方法很长也很重要, 英文的说明很清楚, 我简单翻译一遍捋一下.

/**
  * @param  staleSlot index of the first stale entry encountered while
  *         searching for key.
  */
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
	
	// 从staleSlot出发向前直到找到第一个null的entry, 将遇到的最前面的stale的节点保存在slotToExpunge中
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    // 如果key存在的则从staleslot往后应该会找到这个key, 否则会找到第一个null的entry节点
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        // 如果从staleSlot往后找到了key, 则更新它并将这个key这个位置的entry 和第一个腐败的entry交换
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            // 如果往前找的时候没有staleSlot, 说明i以前是没有stale的节点的, 直接向后清理
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); // 先准确清除这个run中的腐败元素, 然后往后检查log(n)个
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        // 往前找的时候没有staleslot, 往后找遇到的staleSlot之后第一个stale的元素, 用slotToExpunge指向它
        // staleSlot一定会被处理的, slotToExpunge指向的是slot之后开始需要被清除的位置
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    // 如果没有找到key 直接把staleSlot位置鸠占鹊巢
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    // 如果向前或向后遇到了其他的stale的元素, 从那个位置开始处理向后的元素
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); // 先准确清除这个run中的腐败元素, 然后往后检查log(n)个
}

接下来是第二个出镜率很高的方法, expungeStaleEntry, 这个方法的功能是, 给定一个已知的腐烂的节点的index, 它将清除这个节点并将其后直到第一个null的节点之间的可能因为hash冲突排到不是自己位置的节点进行rehash, 这个过程中如果遇到了其他腐烂的节点也会清除, 最终返回的是index之后的第一个null的节点的index, 这个返回值和传入参数之间的部分全部是被检查过的

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null; // 直到遇到null节点
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            // 遇到了stale的 直接清除
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                // 重新从它自己hash对应的位置往后线性探测找到位置, 
                // 这里和我想的不一样, 我以为是会把线性冲突的元素往前移, 当前做法是O(n)的
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

另一个在set添加一个key和expunge一个stale的元素时会被调用的是cleanSomeSlots(i, n), 这个方法从i位置往后会检查log_2(n)次元素是否腐败, 如果有则会调用expunge清除, 并返回true, 否则返回false

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            // 注意这里如果有stale的元素, 就不再是往后检查logN个, 而是往后检查logN次, 因为i被替换意味着会跳过一个run的长度
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

此外还有一个expungeStaleEntries方法对全表进行一次检查, 并调用expunge清除腐败元素.

至此, ThreadLocal中处理内存泄露的清理方法就分析完了, 它们的功能在这里进行简单总结

  • replaceStaleEntry 在set方法的线性探测阶段, 遇到的第一个腐败的entry, 则会将这个entry替换成set要插入或修改的键值对, 并且在这个腐败entry所在的run中检查是否有其他腐败元素调用expunge一并删除
  • expungeStaleEntry清除这个节点并从这个节点往后直到第一个null的节点之间的可能因为hash冲突排到不是自己位置的节点进行rehash, 这个过程中如果遇到了其他腐烂的节点也会清除
  • cleanSomeSlots 从i位置往后会检查log_2(n)次元素是否腐败, 如果有则会调用expunge清除, 并返回true, 否则返回false

inheritableThreadLocals

最后再来简单说下Thread中的这个变量. 说这个变量前首先要介绍下InheritableThreadLocalThreadLocal的区别

当在一个线程中创建子线程时, 子线程是访问不了父线程中的ThreadLocal的, 但如果希望它能够继承父线程的这部分变量, 并且保存一个独立的副本呢, 就采用InheritableThreadLocal. 而inheritableThreadLocals就是Thread中InheritableThreadLocal存放的位置.

资料参考:

  • 第4讲 | 强引用、软引用、弱引用、幻象引用有什么区别?

你可能感兴趣的:(JavaSE基础与源码分析)