ThreadLocal

ThreadLocal类提供了线程局部 (thread-local) 变量。这些变量与普通变量不同,每个线程都可以通过其 get 或 set方法来访问自己的独立初始化的变量副本。ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。


image.png

ThreadLocal的源码分析

Thread类中有个变量threadLocals,类型为ThreadLocal.ThreadLocalMap,这个就是保存每个线程的私有数据。

public
class Thread implements Runnable {

   ThreadLocal.ThreadLocalMap threadLocals = null;
}

image.png

首先,主线程定义的两个ThreadLocal变量,和两个子线程——线程A和线程B。
线程A和线程B分别持有一个ThreadLocalMap用于保存自己独立的副本,主线程的ThreadLocal中封装了get()和set()之类的方法。
在线程A和线程B中调用ThreadLocal的set方法,会首先通过getMap(Thread.currentThread)获得线程A或者是线程B持有的ThreadLocalMap,在调用map.put()方法,并将ThreadLocal作为key。
get()方法和set()方法原理类似,也是先获取当前调用线程的ThreadLocalMap,再从map中获取value,并将ThreadLocal作为key。

ThreadLocalMap是ThreadLocal的内部类,每个数据用Entry保存,其中的Entry继承与WeakReference,用一个键值对存储,键为ThreadLocal的引用。为什么是WeakReference呢?如果是强引用,即使把ThreadLocal设置为null,GC也不会回收,因为ThreadLocalMap对它有强引用。

        static class Entry extends WeakReference> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
        }

ThreadLocal中的set方法的实现逻辑,先获取当前线程,取出当前线程的ThreadLocalMap,如果不存在就会创建一个ThreadLocalMap,如果存在就会把当前的threadlocal的引用作为键,传入的参数作为值存入map中。

  public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
}

不难看出,先获取当前线程的Thread对象,再得到该Thread对象的ThreadLocalMap 成员map,若map为空,需要先createMap()方法,若不为空,则需要调用map的set()方法

void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
}
private void setThreshold(int len) {
            threshold = len * 2 / 3;
}

createMap方法会创建一个ThreadLocalMap对象,在ThreadLocalMap(ThreadLocal firstKey, Object firstValue)构造方法中,可以看出和HashMap很相似,通过firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)取模,计算出哈希表的下标,将创建好的Entry对象放入该位置,再根据表长计算阈值,可以看出负载因子是2/3,初始哈希表的大小是16。

private void set(ThreadLocal key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

不难看出,通过key.threadLocalHashCode & (len-1)计算出哈希表的下标,判断该位置的Entry是否为null,若为null,则创建Entry对象,将其放入该下标位置;若Entry已存在,则需要解决哈希冲突,重新计算下标。最后size自增,再根据!cleanSomeSlots(i, sz) && sz >= threshold进行判断是否需要进行哈希表的调整。

在解决哈希冲突的上,常用的有开链法、线性探测法和再散列法,HashMap中使用的是开链法,而ThreadLocal使用的是线性探测法,即发生哈希冲突,往后移动到合适位置。

private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
}
private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
}

从这两个操作看出,ThreadLocal中的哈希表是利用了循环数组的方式,进行环形的线性探测
在上述for循环中,会取出该Entry上的ThreadLocal对象(键)进行判断,若相同则直接覆盖,若为null,说明该Entry空间存在但其ThreadLocal对象的指向为null,需要进行调整;若都不成立,则继续循环,重复以上操作。

Entry空间指向存在但ThreadLocal对象的指向为null是因为Entry继承自WeakReference>,是弱引用,存在被GC的情况,所以会存在这种情况,视为脏Entry,接下来的操作就是通过replaceStaleEntry进行处理。

private void replaceStaleEntry(ThreadLocal key, Object value,
                                       int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal k = e.get();

        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

可以清楚看到第一个for循环前向遍历查找脏Entry,用slotToExpunge保存脏Entry下标;
第二个for循环后向遍历,若遇到ThreadLocal向同,更新value,然后与下标为staleSlot(传入进来的脏Entry)进行交换,接着判断前向查找脏Entry是否存在,slotToExpunge == staleSlot说明的就是前向查找没找到,就更改slotToExpunge的值,然后进行清理操作,结束掉;若后向遍历遇到脏Entry,并且前向没找到,更改slotToExpunge的值,为清理时用,继续循环。
若不存在和ThreadLocal引用相同的Entry,则需要将staleSlot的位置的Entry替换为一个新的Entry对象,tab[staleSlot].value = null是为了GC;
最后根据slotToExpunge来判断前向后向遍历中是否存在脏Entry,若存在还需要进行清理。

其中的expungeStaleEntry方法如下

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

可以看到,先把当前位置的脏Entry清除掉(置为null),size自减。然后从当前位置后向遍历,若遇到脏Entry直接清除,size自减;若不是脏Entry,则需要判断它是否经过哈希冲突的调整的,若调整过,需要将其重新调整,最后返回当前位置为null的table下标;综上,该方法就是后向清除脏Entry,再把调整需要调整的Entry。

在replaceStaleEntry方法中,调用expungeStaleEntry清除掉脏Entry后,还要用cleanSomeSlots方法清除掉返回回来的下标后的脏Entry;

cleanSomeSlots方法:

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

从下标为i后面的开始后向遍历,遇到脏Entry调用expungeStaleEntry清除掉,令removed为true,i会变为下标为null的位置,继续循环;其中n的用途是控制循环次数,当遇到脏Entry时,会令n等于表长,扩大搜索范围。

在set方法中,最后根据!cleanSomeSlots(i, sz) && sz >= threshold,判断是否清理掉了脏Entry,若清理了什么都不做;若没有清理,还会判断是否达到阈值,进而是否需要rehash操作;

rehash方法:

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

首先调用expungeStaleEntries方法:

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

可以看到expungeStaleEntries方法是遍历整个哈希表,通过调用expungeStaleEntry方法清除掉所有脏Entry。
由于清除掉了脏Entry,还需要对size进行判断,看是否达到了阈值的3/4(提前触发resize),来判断是否真的需要resize;

resize方法:

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

刚开始的操作可以清楚的明白,每次扩容的大小都是原来的两倍;然后遍历原表的所有Entry,遇到脏Entry直接赋值null引起帮助GC;遇到有效Entry则需要根据新的表长重新计算下标,再通过线性探测完成新表的填充;填充完毕,计算新的阈值,给size和table赋值,结束操作。

至此,有关set的操作就结束了,还剩下get和remove:

get方法:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

和set一样,先获取当前线程,再根据当前线程获取其ThreadLocalMap成员map;
若map不为null,通过map的getEntry方法得到Entry对象,若Entry不为null则直接返回Entry的value;
若map为null,或者map不为null,但是Entry是null,则都需要调用setInitialValue方法。

getEntry方法:

private Entry getEntry(ThreadLocal key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

根据ThreadLocal定位哈希表的下标,若满足则直接返回,若不是,调用getEntryAfterMiss继续找。

getEntryAfterMiss方法:

private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

看以看到这还是一个后向遍历的查找,若是找到则直接返回;若遇到脏Entry需要调用expungeStaleEntry方法清理掉;最后还没找到返回null。

setInitialValue方法:

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
       map.set(this, value);
    else
       createMap(t, value);
    return value;
}

先调用initialValue方法,该方法需要使用者进行覆盖,否则返回的是null。所以当没有使用set方法时覆盖initialValue方法时还是会调用set方法的,效果是一样的。

protected T initialValue() {
         return null;
 }

后面的操作就和set方法一样。get方法至此结束。

remove方法:

 public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread());
     if (m != null)
         m.remove(this);
 }


以当前线程为参数调用getMap方法:

ThreadLocalMap getMap(Thread t) {
     return t.threadLocals;
 }

若是当前线程的ThreadLocalMap对象不存在,什么都不做,若存在,调用内部的remove方法:

private void remove(ThreadLocal key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

首先根据ThreadLocal找到其对应的的哈希表的下标(不一定是它的下标,会有哈希冲突的可能性),然后开始后向遍历,找到真正的位置,调用clear方法删除掉,顺便还进行脏Entry的清理。

clear方法是Reference类的方法:

public void clear() {
     this.referent = null;
 }

可以看到仅仅只是令指向变为null,因为Reference是WeakReference的父类,ThreadLocalMap继承自WeakReference>,弱引用变为null,就会变成脏Entry,所以就需要expungeStaleEntry对其清理。为什么不令tab[i]直接为null,就是因为在expungeStaleEntry执行时还会清理遇到的脏Entry,这样可以尽可能多的删除掉脏Entry。

你可能感兴趣的:(ThreadLocal)