接上一篇《java并发系列(3)——线程协作:wait,notify,join》
在多线程场景下,针对并发访问共享变量的问题:
示例代码:
package per.lvjc.concurrent.threadlocal;
public class ThreadLocalTest {
private static class IntThreadLocal extends ThreadLocal<Integer> {
@Override
protected Integer initialValue() {
return 0;
}
}
private static final IntThreadLocal threadLocal = new IntThreadLocal();
public static void main(String[] args) {
Runnable runnable = () -> {
int i = 0;
while (i++ < 10000) {
int value = threadLocal.get();
value = value + 1;
threadLocal.set(value);
}
System.out.println(Thread.currentThread().getName() + ":" + threadLocal.get());
};
new Thread(runnable, "thread1").start();
new Thread(runnable, "thread2").start();
}
}
这里两个线程使用的是同一个 Runnable 对象,同一个 ThreadLocal 对象,各自从 ThreadLocal 对象中取值自加 10000 次。
执行结果:
thread1:10000
thread2:10000
说明不同的线程从同一个 ThreadLocal 对象中取值做运算再写入,是互不干扰的。
在单线程的场景下,也可以把 ThreadLocal 作为线程上下文变量来用。
示例代码:
package per.lvjc.concurrent.threadlocal;
public class ThreadContext {
private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();
public static void main(String[] args) {
threadLocal.set("some values");
print();
}
private static void print() {
printA();
}
private static void printA() {
printB();
}
private static void printB() {
printC();
}
private static void printC() {
printX();
}
private static void printX() {
//旧业务逻辑
System.out.println("old");
//添加新的业务逻辑
String value = threadLocal.get();
System.out.println("new:" + value);
}
}
比如这里有个 printX 方法,现在需要添加新的业务逻辑,而新的业务逻辑需要用到一个 value 变量,但这个变量方法入参里又没有。
怎么办?给 printX 方法加个入参,让上一层把 value 变量传进来。于是找到 printC 方法,让 printC 方法传一个 value 变量,但 printC 方法也没有 value 变量。于是 printC 方法也要加一个 value 入参,再让上一层传进来…最后一直往上找到了 Main 方法。结果发现,就为了拿到一个变量增加一点简单的逻辑,改了一万行代码。
那么,这时候 ThreadLocal 就派上用场了,定义一个 ThreadLocal 变量,在最开始的 Main 方法把 value 设进去,在 printX 方法就可以直接拿到了,中间一层一层的方法调用一个都不用改。当然,前提是 Main 方法和 printX 方法在同一个线程。
见前面示例代码。
如果不先 set,直接 get 会得到 null。
清除 ThreadLocal 变量中设置的值。
这个方法很重要,如果 set 值用完之后不及时 remove,可能会导致:
ThreadLocal 变量不先 set 直接 get 会得到初始值,默认的初始值是 null。
设置初始值示例:
package per.lvjc.concurrent.threadlocal;
public class ThreadLocalTest {
private static class IntThreadLocal extends ThreadLocal<Integer> {
@Override
protected Integer initialValue() {
return 0;
}
}
private static final IntThreadLocal threadLocal = new IntThreadLocal();
}
扩展 ThreadLocal 类,override initialValue 方法。
或者简单写成匿名类的方式:
private static final ThreadLocal<Integer> t = new ThreadLocal<Integer>() {
@Override
protected Integer initialValue() {
return 0;
}
};
或者再简化成:
private static final ThreadLocal<Integer> th = ThreadLocal.withInitial(() -> 0);
正常情况 initialValue 方法一个线程最多只会调用一次。
简单看下这部分的源码:
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
这是唯一调用 initialValue 方法的地方,会在未调用 set 方法之前,第一次调用 get 方法的时候触发。
不过,使用“非正常”的手段也可以让 initialValue 方法调用多次:
package per.lvjc.concurrent.threadlocal;
import java.lang.reflect.Field;
public class ClearThreadLocalTest {
private static final ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 0);
public static void main(String[] args) throws NoSuchFieldException, IllegalAccessException {
System.out.println(threadLocal.get());
threadLocal.set(null);
System.out.println(threadLocal.get());
clearThreadLocal();
System.out.println(threadLocal.get());
}
private static void clearThreadLocal() throws NoSuchFieldException, IllegalAccessException {
Thread currentThread = Thread.currentThread();
Field threadLocalMap = Thread.class.getDeclaredField("threadLocals");
threadLocalMap.setAccessible(true);
threadLocalMap.set(currentThread, null);
}
}
执行结果:
0
null
0
可以看到在调用了 clearThreadLocal 方法之后,再去 get,又触发了一次 ThreadLocal 初始值的设置。当然,这里用反射就比较赖皮了,正常情况最多只会被调用一次。
这张图的含义在下面结合源码来解释。
ThreadLocal 的数据存储结构在上面的图中已经表示出来了,关键就是那个 ThreadLocalMap,看源码:
public
class Thread implements Runnable {
//...
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
/*
* InheritableThreadLocal values pertaining to this thread. This map is
* maintained by the InheritableThreadLocal class.
*/
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
//...
}
可以看到,ThreadLocal 数据(threadLocals 成员变量)实际上是被 Thread 对象持有的,数据类型就是 ThreadLocal.ThreadLocalMap。
再看 ThreadLocal.ThreadLocalMap 是什么:
public class ThreadLocal<T> {
//...
static class ThreadLocalMap {
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
private static final int INITIAL_CAPACITY = 16;
private Entry[] table;
private int size = 0;
private int threshold; // Default to 0
private static int nextIndex(int i, int len) {}
private static int prevIndex(int i, int len) {}
private Entry getEntry(ThreadLocal<?> key) {}
private void set(ThreadLocal<?> key, Object value) {}
private void remove(ThreadLocal<?> key) {}
private void rehash() {}
private void resize() {}
}
}
这里只贴出来 ThreadLocalMap 类中的部分方法,其它的省略了。
从源码可以看出,ThreadLocalMap 是 ThreadLocal 类中的一个内部类,ThreadLocalMap 中又有一个内部类 Entry。
看 ThreadLocalMap 中的这些变量和方法,很明显:
很典型的 hash 结构。
到此就有了上面那张图:
从前面讲的 ThreadLocal 的数据存储结构可以知道,每个 Thread 都有一个自己的 map,那么就有一个问题:当我们通过 Thread#set 方法设值时,这个变量被放到了哪个 map 里面?
看源码:
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
很清晰:
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
也很清晰:
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
也是很简单的 map 操作。
到这里,已经可以自己实现一个简易版的 ThreadLocal 了,详情参看《手写ThreadLocal》。
前面已经看了一部分源码,可以知道:
所以,ThreadLocal 在运行时占用的内存如上图所示。
ThreadLocal 获取 value 的过程:
在这个过程中,如果 ThreadLocal 对象被回收了,那么这个 ThreadLocal 对象计算出来的 index 位置的 Entry 就再也找不到了。但如果 Thread 对象还在,这个 Entry 就回收不掉,因为从 Thread 到 Entry 一路都有强引用。于是就出现一个回收不掉又找不到的 Entry,这块内存就泄漏了。
ThreadLocal 内存泄漏的本质:map 数据结构本身就存在内存泄漏的风险。
比如,以 HashMap 为例:
package per.lvjc.concurrent.threadlocal;
import java.util.HashMap;
import java.util.Map;
public class MapMemoryLeak {
private static Map<Object, byte[]> map = new HashMap<>();
private static void put(byte[] value) {
map.put(new Object(), value);
}
public static void main(String[] args) {
int i = 0;
while (true) {
put(new byte[1024 * 1024]);
System.out.println("i = " + i++);
}
}
}
如果像这样错误地使用 HashMap,那 HashMap 也会内存泄漏。
执行结果:
...
i = 14
i = 15
i = 16
Exception in thread "main" java.lang.OutOfMemoryError: Java heap space
at per.lvjc.test.MapMemoryLeak.main(MapMemoryLeak.java:17)
Process finished with exit code 1
这里因为 HashMap 不会被 gc,所以 HashMap 里面的 key,value 也都没法被 gc,最后就内存溢出了。
但这个 HashMap 里面的数据又没法用,因为 key 的引用已经找不回来了,就算遍历整个 HashMap 也没法知道哪个 value 是我想要的。
所以,ThreadLocal 发生内存泄漏的条件是:
(1)value 还在
这包含两个条件:
(2)key 找不到
根据以上分析的 ThreadLocal 会发生内存泄漏的条件,使用 ThreadLocal 时应该注意:
提到 ThreadLocal 就不得不提它里面用到的一个弱引用。此前一直没有提这个,是因为个人认为 ThreadLocal 潜在的内存泄漏风险跟这个弱引用没有一点关系。
先看一下弱引用在什么地方用到了。
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
Entry 继承了 WeakReference,它是一个指向 ThreadLocal 对象的弱引用。
弱引用是对垃圾收集器的一个提示:“下次 gc 的时候,如果没有其它强引用,请把这个对象回收掉,我已经不用了。”
这意味着,如果栈内存中对 ThreadLocal 对象的强引用断开,那么下次 gc 的时候,ThreadLocal 对象就会被回收掉。
Entry 是 ThreadLocalMap 中的一个内部类,所以先看一下 ThreadLocalMap,搞清楚这里为什么要特别实现一种新的 map。
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
普通的 map 在向里面存放数据的时候一般是,根据 key 计算出当前值应该放在数组的哪个位置,然后把数据放进去,要么新建,要么覆盖。
但 ThreadLocalMap 不一样:
if (k == key) {
e.value = value;
return;
}
这块是一般 map 的逻辑:
如果计算出来的这个位置上还没有数据,也就是e == null
,那么跳出 for 循环,执行:
tab[i] = new Entry(key, value);
在这个位置上新建并存入一个新的元素。
这些都是一般 map 存放数据时的逻辑,而 ThreadLocalMap 多做了一个操作,在 for 循环里面:
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
这块代码会得到执行的场景是:
k == null
);所以,ThreadLocalMap 在插入数据的时候,会做一个判断,如果发现这个位置内存泄漏了,它会做一个操作:
replaceStaleEntry(key, value, i);
看这个方法它干了什么:
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// Back up to check for prior stale entry in current run.
// We clean out whole runs at a time to avoid continual
// incremental rehashing due to garbage collector freeing
// up refs in bunches (i.e., whenever the collector runs).
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
// Find either the key or trailing null slot of run, whichever
// occurs first
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// If we find key, then we need to swap it
// with the stale entry to maintain hash table order.
// The newly stale slot, or any other stale slot
// encountered above it, can then be sent to expungeStaleEntry
// to remove or rehash all of the other entries in run.
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// Start expunge at preceding stale entry if it exists
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// If we didn't find stale entry on backward scan, the
// first stale entry seen while scanning for key is the
// first still present in the run.
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
这个方法稍微有点复杂,大意是:
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// expunge entry at staleSlot
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
// Rehash until we encounter null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
主要看这几行:
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//...
}
}
当发现k == null
即 key 已经被回收掉的时候,就把 key 对应的 value 的引用也给断掉,然后再把数组对这个位置的元素的引用也断掉。这样下次 gc 的时候,这块泄漏的内存就能被回收了,也就解决了内存泄漏的问题。
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
普通 map 的做法一般是,根据 key 计算出这个值是存放在哪个位置的,然后到数组里面取出这个位置的元素返回,如果为空就返回空。
ThreadLocalMap 不一样,从源码可以看到,当这个位置的元素为空,或e.get() != key
时,它额外做了一个 getEntryAfterMiss 的操作。
e.get() 为什么会不等于 key 呢?因为这里是弱引用,如果 key(即 ThreadLocal)没有被回收,正常情况 e.get() 就等于 key;而如果 key 被回收,那么 e.get() == null,这时候就走到了 else 里面的逻辑。
这里说的是正常情况,e.get() == key,可能存在一种情况是,两个 key 计算出来的 i 是相同的,这两个元素不可能都放在 i 位置上,于是必然存在一个 key,它计算出来了一个 i 位置,但这个位置却被其它元素抢占了,就会出现 e.get() != key 的情况。当然,这是 map 数据结构的 hash 问题,不是这里的重点。
下面看 getEntryAfterMiss 方法:
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
关键看这两行代码:
if (k == null)
expungeStaleEntry(i);
跟前面 set 方法一样,getEntry 方法也会对 key 进行校验,如果发现 key 被回收掉了,同样还是调用 expungeStaleEntry 方法扫描回收一部分可能已经泄漏的内存。
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
remove 方法也一样会调用 expungeStaleEntry 方法来回收一部分可能已经泄漏的内存。
这里使用弱引用与强引用的对比:
另一方面,从前面对 ThreadLocalMap 源码的分析来看,ThreadLocalMap 恰好可以借助这个弱引用来判断内存是否已经泄漏了:
所以,弱引用不是 ThreadLocal 内存泄漏的原因,相反,它是 JDK 针对内存泄漏而做的优化。
如果把这里的弱引用改为强引用,ThreadLocal 就不可能被回收,key 就不可能为 null,ThreadLocalMap 里面多处存在的当 key == null 时 expungeStaleEntry 的操作就根本不会得到执行,ThreadLocalMap 的整套机制就崩了。
虽然前面已经讲了,借助于弱引用,ThreadLocalMap 能判断是否内存已经泄漏了,从而把相应的 value 释放掉。
但我们能注意到,ThreadLocalMap 释放泄漏的内存的时机是:
那问题来了,如果没那么巧呢?
在 set,get 的时候没有恰好操作 key 为 null 的位置,而是对一个正常的位置做操作,那么 ThreadLocalMap 就不会去扫描是否有些位置内存已经泄漏了。(remove 就不谈了,如果会调 remove 方法,就不会内存泄漏了。)
而在这么巧的事情发生之前,已经泄漏的内存就只能先泄漏着。
所以,ThreadLocal 虽然通过 ThreadLocalMap 已经对潜在的内存泄漏风险做了处理,但未必会及时。