今天,我们来聊聊ThreadLocal
ThreadLocal是什么
ThreadLocal相信大家就算没有用过也听过,他可以为每个使用该变量的线程分配一个独立的变量副本,所以每一个线程都可以独立的改变自己的副本,不会影响到其他的线程。从线程的角度来看,这个变量就像是线程的本地变量,所以这也就是local要表达的意思。既然每个线程都有了自己的本地变量,所以也就不存在多线程共享变量的问题,他与synchronized关键字不同:
- synchronized是通过线程等待,以时间来换空间
- ThreadLocal是通过每个线程单独一份存储空间,以空间来换取时间
如何使用ThreadLocal
那么说了一大通,ThreadLocal究竟如何使用呢?其实他的使用方式很简单,如下所示,是ThreadLocal在spring事务管理中的应用:
在DataSourceTransactionManager中,doBegin()方法开启事务
@Override
protected void doBegin(Object transaction, TransactionDefinition definition) {
DataSourceTransactionObject txObject = (DataSourceTransactionObject) transaction;
Connection con = null;
try {
if (!txObject.hasConnectionHolder() ||
txObject.getConnectionHolder().isSynchronizedWithTransaction()) {
Connection newCon = obtainDataSource().getConnection();
if (logger.isDebugEnabled()) {
logger.debug("Acquired Connection [" + newCon + "] for JDBC transaction");
}
txObject.setConnectionHolder(new ConnectionHolder(newCon), true);
}
txObject.getConnectionHolder().setSynchronizedWithTransaction(true);
con = txObject.getConnectionHolder().getConnection();
Integer previousIsolationLevel = DataSourceUtils.prepareConnectionForTransaction(con, definition);
txObject.setPreviousIsolationLevel(previousIsolationLevel);
if (con.getAutoCommit()) {
txObject.setMustRestoreAutoCommit(true);
if (logger.isDebugEnabled()) {
logger.debug("Switching JDBC Connection [" + con + "] to manual commit");
}
con.setAutoCommit(false);
}
prepareTransactionalConnection(con, definition);
txObject.getConnectionHolder().setTransactionActive(true);
int timeout = determineTimeout(definition);
if (timeout != TransactionDefinition.TIMEOUT_DEFAULT) {
txObject.getConnectionHolder().setTimeoutInSeconds(timeout);
}
// 这里进行绑定
if (txObject.isNewConnectionHolder()) {
TransactionSynchronizationManager.bindResource(obtainDataSource(), txObject.getConnectionHolder());
}
}
catch (Throwable ex) {
if (txObject.isNewConnectionHolder()) {
DataSourceUtils.releaseConnection(con, obtainDataSource());
txObject.setConnectionHolder(null, false);
}
throw new CannotCreateTransactionException("Could not open JDBC Connection for transaction", ex);
}
}
首先从链接池中获取一个connection,然后开始事务,最后通过TransactionSynchronizationManager将connection与当前线程进行绑定。
public static void bindResource(Object key, Object value) throws IllegalStateException {
Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key);
Assert.notNull(value, "Value must not be null");
// resources就是一个ThreadLocal对象,value是一个map
Map
那在我们的项目中是如何使用的呢?如下所示:
public class DateUtil {
private static final ThreadLocal dateFormatter = ThreadLocal.withInitial(() ->
new SimpleDateFormat("yyyy-MM-dd"));
public static String addOneDay(String date) throws Exception {
if (StringUtils.isBlank(date)) {
return null;
}
SimpleDateFormat sf = dateFormatter.get();
Date origDate = sf.parse(date);
return sf.format(origDate.getTime() + 3600000 * 24);
}
}
使用SimpleDateFormat作为共享变量会存在线程安全问题,所以这里使用了ThreadLocal来为每个线程创建一个SimpleDateFormat对象副本,这样就解决了这个对象的线程安全问题,Java 8专门对日期处理类进行了优化,使用Java8的日期处理类是线程安全的。
ThreadLocal的原理
ThreadLocal原理图解
图中的实线为强引用,虚线为弱引用
set(T value)方法源代码解析
ThreadLocal set(T value)方法和ThreadLocalMap set(ThreadLocal> key, Object value)方法
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 如果map存在,则放入,否则,创建map放入
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
从上面的代码可以看出,每个线程只能持有一个ThreadLocalMap对象,那么我们先来看看ThreadLocalMap是什么东西:
static class ThreadLocalMap {
// 这里可以看到,Entry的key是一个弱引用
static class Entry extends WeakReference> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal> k, Object v) {
super(k);
value = v;
}
}
ThreadLocalMap(ThreadLocal> firstKey, Object firstValue) {
// 初始化容量和hashmap一样,INITIAL_CAPACITY = 16
table = new Entry[INITIAL_CAPACITY];
// 位运算,计算出第一个key存放的位置,threadLocalHashCode留一下,下面会说明
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
}
通过以上代码我们可以看出,ThreadLocalMap本质上就是一个map,他实例化的时候创建了一个长度为16的Entry数组,通过位运算得出一个i,这个i就是存储在table数组中的位置。如果我们同时实例化两个ThreadLocal对象,如下所示:
ThreadLocal a = new ThreadLocal<>();
ThreadLocal b = new ThreadLocal<>();
由前面的知识我们知道,同一个thread只会有一个ThreadLocalMap,那么为了管理a,b,于是把他们放在数组不同的位置,那么是如何进行存放的呢?我们看看ThreadLocalMap set()方法的源代码:
private void set(ThreadLocal> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// 这里是重点,熟悉的位运算
int i = key.threadLocalHashCode & (len-1);
// 如果table[i]有值,则替换
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal> k = e.get();
if (k == key) {
e.value = value;
return;
}
// 这里留意下,后续会说明
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 如果没有,则创建新值
tab[i] = new Entry(key, value);
int sz = ++size;
// 如果满足条件,则进行扩容,跟hashmap扩容原理相似,都是重新hash
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
好,set方法咋一看挺简单的,位运算是重点,我们单独拿出来看看他是如何实现同一个线程存储不同的ThreadLocal对象:
private final int threadLocalHashCode = nextHashCode();
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
private static final int HASH_INCREMENT = 0x61c88647;
private static AtomicInteger nextHashCode = new AtomicInteger();
好,以上就是threadLocalHashCode这个变量的全部操作了,可以看到,在new ThreadLocal的时候,这个变量的值便固定了,第一次为0,第二次为0x61c88647,以后每次都加上这么一个数:0x61c88647,那么为什么是这个数不是其他数呢?根据官方的注释大概翻译就是这个值是斐波那契散列乘数,通过与他hash出来的结果分布会比较的均匀。通过以上分析就很清晰了,同一个线程,new多个ThreadLocal对象,new的时候,他的threadLocalHashCode便已经生成了,通过位运算,可以把他放到table的不同位置。
get()方法源代码解析
看完了set()方法,我们就来看看get()方法,就简单多了:
public T get() {
Thread t = Thread.currentThread();
// 获取ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
// 获取Entry
ThreadLocalMap.Entry e = map.getEntry(this);
// 不为null的话返回Entry的value
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
上面就是get()方法,是不是很简单?接下来,我们就深入的看看map.getEntry()方法:
private Entry getEntry(ThreadLocal> key) {
// 同样的位运算
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 在table数组中找到Entry后,返回Entry
if (e != null && e.get() == key)
return e;
else
// 这里后面说
return getEntryAfterMiss(key, i, e);
}
以上就是get()方法了,由set()方法的知识来看这个方法,还是挺简单的,接下来我们看看remove()方法
remove()方法源代码解析
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
remove()方法就更简单了,调用ThreadLocalMap的remove方法,来移除table的值:
private void remove(ThreadLocal> key) {
Entry[] tab = table;
int len = tab.length;
// 位运算,获取数组下标i
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// 这个clear()方法是弱引用的方法,就是把Entry的Key置为null
e.clear();
// 置为null后,value也置为空,size--
expungeStaleEntry(i);
return;
}
}
}
以上就是remove()方法,也挺简单的,到这里基本的源码分析就完结了,肯定有人说,上面不是有部分注释写的放在下面说吗?怎么就没了,放下面说是放在警惕内存泄漏模块,我们接着往下看。
警惕ThreadLocal内存泄漏
ThreadLocal为什么会内存泄漏
我们回到文章上面那种图解,图中虚线是一个弱引用,弱引用在gc的时候是会被回收的,那么key就变成了null,如果thread的生命周期比较长,或者说使用的是线程池,那么就会存在这么一条引用链:thread引用->current thread->ThreadLocalMap->Entry->value->Object,在gc可达性分析的时候,这条引用链是可达的,从而造成无法gc,造成内存泄漏。
ThreadLocal是如何处理内存泄漏的
那么既然有可能会造成内存泄漏,编写ThreadLocal的人肯定考虑到了,我们就来看看ThreadLocal是如何“自救”的,也就是看我们上面说的需要注意的几个地方,我们先看set()方法里的cleanSomeSlots()方法,为了方便查看,我们把set再拷贝下来:
private void set(ThreadLocal> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);=
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
我们看到,第一次执行set()方法,for循环是不会进去的,那么会直接执行下面的方法,下面有个cleanSomeSlots()方法:
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
这个方法什么意思呢?顾名思义,扫描看是否有脏Entry(key为null),这里的i是插入Entry的位置,n是table当前的真实的Entry个数也就是size,不是table的长度16。n还用来控制扫描的次数,while里的条件是(n >>>= 1) != 0,n无符号右移1位,直到n==0为止,所以扫描的次数是log2(n),如果在扫描的途中,遇到脏Entry,那么n就会变成table的长度也就是length,扩大搜索范围,我们看到遇到脏Entry的时候会有这个一个方法,expungeStaleEntry(i),名称就是清理脏Entry,我们来看看是如何清理的,其实也不难,入参的staleSlot 就是脏Entry的数组下标:
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 置为null,断开引用链,便于gc回收空间
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
// 不会就此停止,而是继续向后查找,知道table[i] == null结束,nextIndex方法为:
// private static int nextIndex(int i, int len) {
// return ((i + 1 < len) ? i + 1 : 0);
//}
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal> k = e.get();
// 如果再次遇到脏Entry,继续清理
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// 看下是否需要重新hash
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
看方法的注释其实基本已经明了了,这个方法首先会把i这个脏Entry给清理掉,清理完之后他并没有闲着,而是会向后去继续的清理脏Entry,直到table[i] == null为止,table[i] == null意味着Entry==null,肯定可不能是脏Entry。
那如果是第二次执行set()方法,是会进入set()方法里的for循环,如果Entry的key相等,那么直接覆盖并且return,那如果Entry是一个脏Entry,key == null,会进入replaceStaleEntry()这个方法,如下所示,这个方法的入参,key和value为set()方法的key和value,i就是脏Entry的数组下标:
private void replaceStaleEntry(ThreadLocal> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
int slotToExpunge = staleSlot;
// 向前查找脏Entry如果找到了,那么slotToExpunge就是脏Entry的下标位置
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal> k = e.get();
// 如果向后查找找到了k == key,则把value赋值给e.value,然后把e和脏Entry进行位置交换
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 如果向前没找到脏Entry,那么以i为下标,去向后查找脏Entry并清理Entry
if (slotToExpunge == staleSlot){
slotToExpunge = i;
}
// 如果向前找到了脏entry,那么以向前脏Entry的下标i向后查找并清理脏Entry,并return
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// 如果向后查找遇到了脏Entry,并且向前没遇到脏Entry,
// 那么slotToExpunge为脏Entry的数组下标
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 把入参的脏Entry value置为null
tab[staleSlot].value = null;
// 把入参脏Entry变成一个正常Entry
tab[staleSlot] = new Entry(key, value);
// 这个情况是for循环向后查找遇到了脏Entry,则以脏Entry下标为位置,向后查找并清理脏Entry
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
基本的解析都写在了方法里面,总共是有4种情况:
- 向前无脏Entry,向后无脏Entry,那么入参直接重建:tab[staleSlot] = new Entry(key, value)
- 向前有脏Entry,向后有可覆盖的Entry,那么把可覆盖的Entry value置为入参value,并且和当前脏Entry进行替换,并以向前脏Entry的i为下标,进行清理脏Entry
- 向前有脏Entry,向后无可覆盖的Entry,那么入参直接重建:tab[staleSlot] = new Entry(key, value),以向前的脏Entry下标i向后查找并清理脏Entry
- 向前无脏Entry,向后有可覆盖的Entry,那么把可覆盖的Entry value置为入参value,并且和当前脏Entry进行替换,并且以可覆盖的Entry下标i为位置,向后查找并清理脏Entry
说完了上面三个方法,其实在get(),remove的时候也是会调用这个三个方法,针对ThreadLocal内存泄漏的问题,就是通过这三个方法来解决的。
ThreadLocalMap的key为什么要使用弱引用
如果Entry的key使用的是强引用的话,在代码中执行:threadLocalInstance == null操作的时候,其实ThreadLocal对象的引用还是可达的,gc的时候进行可达性分析是无法回收掉的。
尽管弱引用会出现内存泄漏的问题,在ThreadLocal中其实提供了get(),set(),remove()进行了一定程度的规避,尽可能达到安全使用的状态。
如何避免内存泄漏
通过以上的源码分析,ThreadLocal的最佳实践已经很清晰了,有以下几点:
- 每次使用完ThreadLocal,调用remove()方法,清除数据。
- 如果使用线程池,一定要进行清理,不然就会导致业务逻辑的问题出现。
参考资料:
- ThreadLocal内存泄漏真因探究:https://www.jianshu.com/p/a1cd61fa22da
- ThreadLocal:https://www.jianshu.com/p/3c5d7f09dfbd
- 一篇文章,从源码深入详解ThreadLocal内存泄漏问题:https://www.jianshu.com/p/dde92ec37bd1