ThreadLocal 实现原理

ThreadLocal 的作用是,对某一共享变量的操作,仅对当前线程可见

public static void main(String[] args) throws InterruptedException {
  ThreadLocal<String> threadLocal = new ThreadLocal<>();
  Thread thread = new Thread(() -> {
    System.out.println(threadLocal.get());    // null
    threadLocal.set("thread");
    System.out.println(threadLocal.get());    // thread
  });
  thread.start();
  thread.join();
  
  System.out.println(threadLocal.get());    // null  其他线程对 threadLocal 的修改,在另外一个线程中,不可见
  threadLocal.set("main");
  System.out.println(threadLocal.get());    // main
}

set 方法

public class ThreadLocal<T> {

  static class ThreadLocalMap {
    static class Entry extends WeakReference<ThreadLocal<?>> {
      //......  
    }
    private void set(ThreadLocal<?> key, Object value) {  
      //......  
    }
  }

  public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);    // ThreadLocalMap getMap(Thread t) { return t.threadLocals; }
    if (map != null)
      map.set(this, value);    // 调用 ThreadLocal.ThreadLocalMap 下的 set
    else
      createMap(t, value);
  }
}

Thread 类中有一个 threadLocals的成员变量,其类型为ThreadLocal.ThreadLocalMap

public class Thread implements Runnable {
    ThreadLocal.ThreadLocalMap threadLocals = null;
}

通过 threadLocals ,可将 Thread 与 ThreadLocal 关联起来

你可以这样理解:ThreadLocal 中的 set 操作的都是 threadLocals ,也就是,只操作当前线程的 threadLocals ,因此,对其他线程不可见

ThreadLocalMap 上有个 table 的数组,用于存放 Entry (数组中的元素要么为 Entry、要么为 null)

Entry 中:

  • key 为指向 ThreadLocal 实例的弱引用
  • value 为 ThreadLocal 的初始化值当前线程中 set 方法的值

同一线程下的 ThreadLocal 共用同一个 ThreadLocalMap(通过 Thread 下的 threadLocals 变量)

不同线程,不同 threadLocals -> 不同 ThreadLocalMap

相同线程,共用 threadLocals -> 相同线程下,不同 ThreadLocal,都是存放在 ThreadLocalMaptable

createMap

public class ThreadLocal<T> {
  void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
  }
  
  static class ThreadLocalMap {
    private Entry[] table;    // <------ 给 table 初始化
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
      table = new Entry[16];
      // firstKey.threadLocalHashCode = new AtomicInteger().getAndAdd(0x61c88647)
      int i = firstKey.threadLocalHashCode & (16- 1);
      table[i] = new Entry(firstKey, firstValue);
      size = 1;
      setThreshold(INITIAL_CAPACITY);
    }
  }
}

在 createMap 函数中执行 new ThreadLocalMap 的时候,将 this 传入,即将 ThreadLocal 实例作为参数,该参数之后通过 hash 计算,可得到 table 中对应索引

通过斐波那契散列,可使分布更加均匀

(32位整数)hash_increment: 2 32 × 0.618 = − 1640531527 ( 十进制 ) = 61 c 88647 ( 十六进制 ) 2 ^ {32} \times 0.618 = -1640531527(十进制) = 61c88647(十六进制) 232×0.618=1640531527(十进制)=61c88647(十六进制)


table 数组中存放的是 Entry 对象,该对象的 key 是个弱引用

static class Entry extends WeakReference<ThreadLocal<?>> {
  Object value;
  Entry(ThreadLocal<?> k, Object v) {
    super(k);    // <----- k 为弱引用
    value = v;
  }
}

调用过 ThreadLocal.set() 方法后,每个线程都会维护一个类型为 ThreadLocalMap 的成员变量

该成员变量名为 threadLocals ,内部有个 table 的数组,存放 Entry 的对象

weakreference

static Object object = new Object();
public static void main(String[] args) {
  WeakReference<Object> objectWeakReference = new WeakReference<>(object);
  objectWeakReference = null;
  System.gc();
  System.out.println(objectWeakReference.get());
}

Exception in thread “main” java.lang.NullPointerException
at org.example.zhang.Demo.main(Demo.java:19)

map.set(this, value)

private void set(ThreadLocal<?> key, Object value) {    // key 为 ThreadLocal 的 this
  // We don't use a fast path as with get() because it is at
  // least as common to use set() to create new entries as
  // it is to replace existing ones, in which case, a fast
  // path would fail more often than not.
  // 我们不像 get() 那样使用 fast path ,因为使用 set() 创建新 entries 与替换现有 entries 一样普遍,在这种情况下,fast path 往往会失败.

  Entry[] tab = table;
  int len = tab.length;
  int i = key.threadLocalHashCode & (len - 1);		// 计算对应槽位的索引

  for (Entry e = tab[i];
       e != null;           // 槽位上为空,则跳出循环(此时可直接在对应槽位上更新,没必要往后遍历了)
       e = tab[i = nextIndex(i, len)]) {    // 情况三:槽位上不为空,但 Entry 上的 key 与当前 ThreadLocal 不相等,进入下一次循环
    ThreadLocal<?> k = e.get();    // 返回 Entry 的 key,实际调用 Reference#get{  return this.referent;  }(WeakReference extends Reference)

    if (k == key) {    // 情况一:槽位上不为空,Entry 上的 key 与当前 ThreadLocal 相等,覆盖
      e.value = value;
      return;
    }

    if (k == null) {    // 情况二:槽位上不为空,而 Entry 上的 key 为空,说明 key(弱引用)被回收了
      replaceStaleEntry(key, value, i);
      return;
    }
  }

  tab[i] = new Entry(key, value);    // 情况四:槽位上为空,直接更新
  int sz = ++size;
  if (!cleanSomeSlots(i, sz) && sz >= threshold)
    rehash();    // 扩容
}

private static int nextIndex(int i, int len) {
  return ((i + 1 < len) ? i + 1 : 0);
}

通过 threadlocal 实例的 hash 值可定位到 table 的具体槽位

  • 情况一:槽位上不为空,Entry 上的 key 与当前 ThreadLocal 相等,覆盖
  • 情况二:槽位上不为空,而 Entry 上的 key 为空,说明 key(弱引用)被回收了,执行 replaceStaleEntry 方法
    • 后面的 Entry key 相等:交换
    • 后端的槽位上为空:在当前 Entry key 为 null 的槽位上直接更新
  • 情况三:槽位上不为空,但 Entry 上的 key 与当前 ThreadLocal 不相等,进入下一次循环
  • 情况四:槽位上为空,直接更新

ThreadLocal 实现原理_第1张图片

replaceStaleEntry

情况二:槽位上不为空,而 Entry 上的 key 为空,说明 key(弱引用)被回收了,执行 replaceStaleEntry 方法

当前位置的 key 已被回收,ThreadLocal 认为这些 key 是无效的,需替换和清理

向前遍历

向前遍历,寻找过期 key,即 Entry 为 null 的槽位(遍历过程中,若遇到槽位上为空的情况,则结束向前遍历)

  • 找到:slotToExpunge = i;
  • 遍历结束后未找到过期 key:因为 prevIndex 为 (i - 1 >= 0) ? i - 1 : len - 1,i = 0 会返回 len -1,即数组中最后一个元素
    这里有个疑问,若整个数组内的元素都不为 null,则永远无法满足 (e = tab[i]) != null 条件,是否会发生死循环?

向后遍历

  • 槽位上为空 :跳出循环,在当前槽位上直接更新
  • Entry key 相等:交换
  • Entry key 不相等:若 Entry key 为 null(过期 key),且 slotToExpunge = staleSlot,则更新 slotToExpunge,即 slotToExpunge = i, 然后进入下一次循环;否则直接进入下一次循环
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
  Entry[] tab = table;
  int len = tab.length;
  Entry e;

  // Back up to check for prior stale entry in current run.
  // We clean out whole runs at a time to avoid continual
  // incremental rehashing due to garbage collector freeing
  // up refs in bunches (i.e., whenever the collector runs).
  // 备份以检查当前运行中的先前的stale(过期)entry。我们一次清理整个运行,以避免由于垃圾收集器释放成束的引用(即,每当收集器运行时)而导致的持续增量的重新散列。
  int slotToExpunge = staleSlot;
  for (int i = prevIndex(staleSlot, len);    // int prevIndex(int i, int len) {  return ((i - 1 >= 0) ? i - 1 : len - 1)  }
       (e = tab[i]) != null;
       i = prevIndex(i, len))
    if (e.get() == null)
      slotToExpunge = i;    // expunge:删除

  // Find either the key or trailing null slot of run, whichever
  // occurs first
  // 查找运行的 key 或尾部为 null 的空槽,以先发生者为准
  for (int i = nextIndex(staleSlot, len);
       (e = tab[i]) != null;
       i = nextIndex(i, len)) {
    ThreadLocal<?> k = e.get();

    // If we find key, then we need to swap it
    // with the stale entry to maintain hash table order.
    // The newly stale slot, or any other stale slot
    // encountered above it, can then be sent to expungeStaleEntry
    // to remove or rehash all of the other entries in run.
    // 如果我们找到key,那么我们需要将它与过时的entry交换以维护哈希表的顺序。
    // 然后可以将新的stale slot或在其上方遇到的任何其他stale slot发送到 expungeStaleEntry 以删除或重新散列运行中的所有其他条目。
    if (k == key) {    // Entry key 相等:交换
      e.value = value;

      tab[i] = tab[staleSlot];    // 交换
      tab[staleSlot] = e;

      // Start expunge at preceding stale entry if it exists
      // 如果存在,则从先前的stale entry开始删除
      if (slotToExpunge == staleSlot)
        slotToExpunge = i;
      cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
      return;
    }

    // If we didn't find stale entry on backward scan, the
    // first stale entry seen while scanning for key is the
    // first still present in the run.
    // 如果我们在反向扫描中没有找到stale entry,那么在扫描key时看到的第一个stale entry是第一个仍然存在于运行中的entry。
    if (k == null && slotToExpunge == staleSlot)
      slotToExpunge = i;
  }

  // If key not found, put new entry in stale slot    未找到,在当前槽位上直接更新
  tab[staleSlot].value = null;
  tab[staleSlot] = new Entry(key, value);

  // If there are any other stale entries in run, expunge them
  // 如果运行中有任何其他 stale entries,删除它们
  if (slotToExpunge != staleSlot)
    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

cleanSomeSlots

  • 在前向遍历的时候,如果还未找到过期 key,遇到了槽位上为 null,则会结束循环,符合 slotToExpunge == staleSlot 条件
    向后遍历时,在交换的情况,之后会执行 cleanSomeSlots(expungeStaleEntry(i), len);,i为要交换的索引(后面的索引)
  • 在前向遍历的时候,如果找到过期 key
    向后遍历时,遇到槽位上为空,直接更新的情况,之后会执行 cleanSomeSlots(expungeStaleEntry(slotToExpunge), len) 方法
private int expungeStaleEntry(int staleSlot) {
  ThreadLocal.ThreadLocalMap.Entry[] tab = table;
  int len = tab.length;

  // expunge entry at staleSlot
  tab[staleSlot].value = null;
  tab[staleSlot] = null;
  size--;

  // Rehash until we encounter null
  ThreadLocal.ThreadLocalMap.Entry e;
  int i;
  for (i = nextIndex(staleSlot, len);
       (e = tab[i]) != null;
       i = nextIndex(i, len)) {
    ThreadLocal<?> k = e.get();
    if (k == null) {  // Entry key 为空,清空过期的 Entry,将其设为 null
      e.value = null;
      tab[i] = null;
      size--;
    } else {    // Entry key 不为空
      int h = k.threadLocalHashCode & (len - 1);     // rehash,使得原本槽位与当前槽位之间的距离更加近
      if (h != i) {
        tab[i] = null;

        // Unlike Knuth 6.4 Algorithm R, we must scan until
        // null because multiple entries could have been stale.
        while (tab[h] != null)
          h = nextIndex(h, len);
        tab[h] = e;
      }
    }
  }
  return i;
}

private boolean cleanSomeSlots(int i, int n) {
  boolean removed = false;
  ThreadLocal.ThreadLocalMap.Entry[] tab = table;
  int len = tab.length;
  do {
    i = nextIndex(i, len);
    ThreadLocal.ThreadLocalMap.Entry e = tab[i];
    if (e != null && e.get() == null) {
      n = len;
      removed = true;
      i = expungeStaleEntry(i);
    }
  } while ((n >>>= 1) != 0);
  return removed;
}

探测式清理流程:从开始位置向后清理数据,清理过程中:(遇到槽位上为 null,跳出循环)

  • Entry key 为空,设为 null
  • Entry key 不为空,rehash,定位到某一槽位,若该槽位不为 null,向后遍历,知道遇到槽位上为 null 的

通过这种方式,使得原本槽位位置(因为槽位上已占据元素,不得不往后移动)与当前槽位位置之间的距离更近了(假设原本槽位位置当前槽位位置的中间位置的某一槽位上的 key 过期,调用 expungeStaleEntry后,减少两槽位之间的距离),可优化散列表的查询性能

rehash

情况四:槽位上为空,直接更新,之后,会进行扩容

tab[i] = new ThreadLocal.ThreadLocalMap.Entry(key, value);    // 情况四:槽位上为空,直接更新
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
  rehash();

t h r e s h o l d = 16 × 2 3 = 10 ( 小数部位舍去 ) 1 4 t h r e s h o l d = 10 4 = 2 t h r e s h o l d − 1 4 t h r e s h o l d = 10 − 2 = 8 threshold = 16 \times \frac{2}{3} = 10(小数部位舍去) \\ \frac{1}{4}threshold = \frac{10}{4} = 2 \\ threshold - \frac{1}{4}threshold = 10 -2 = 8 threshold=16×32=10(小数部位舍去)41threshold=410=2threshold41threshold=102=8

两倍扩容,然后进行 rehash

private void rehash() {
  expungeStaleEntries();    // 清除所有过期 entry

  // Use lower threshold for doubling to avoid hysteresis
  if (size >= threshold - threshold / 4)    // threshold = 16 * 2 / 3 = 10(小数部位舍去),threshold / 4 = 10 / 4 = 2,threshold - threshold / 4 = 8
    resize();
}

/**
 * Double the capacity of the table.
 */
private void resize() {
  ThreadLocal.ThreadLocalMap.Entry[] oldTab = table;
  int oldLen = oldTab.length;
  int newLen = oldLen * 2;
  ThreadLocal.ThreadLocalMap.Entry[] newTab = new ThreadLocal.ThreadLocalMap.Entry[newLen];
  int count = 0;

  for (int j = 0; j < oldLen; ++j) {
    ThreadLocal.ThreadLocalMap.Entry e = oldTab[j];
    if (e != null) {
      ThreadLocal<?> k = e.get();
      if (k == null) {
        e.value = null; // Help the GC
      } else {
        int h = k.threadLocalHashCode & (newLen - 1);
        while (newTab[h] != null)
          h = nextIndex(h, newLen);
        newTab[h] = e;
        count++;
      }
    }
  }

  setThreshold(newLen);
  size = count;
  table = newTab;
}

get 方法

public T get() {
  Thread t = Thread.currentThread();
  ThreadLocal.ThreadLocalMap map = getMap(t);
  if (map != null) {
    ThreadLocal.ThreadLocalMap.Entry e = map.getEntry(this);
    if (e != null) {
      @SuppressWarnings("unchecked")
      T result = (T)e.value;
      return result;
    }
  }
  return setInitialValue();    // map 为空,初始化
}

private ThreadLocal.ThreadLocalMap.Entry getEntry(ThreadLocal<?> key) {
  int i = key.threadLocalHashCode & (table.length - 1);
  ThreadLocal.ThreadLocalMap.Entry e = table[i];
  if (e != null && e.get() == key)
    return e;
  else
    return getEntryAfterMiss(key, i, e);  // 槽位上为空,向后查找
}

private ThreadLocal.ThreadLocalMap.Entry getEntryAfterMiss(ThreadLocal<?> key, int i, ThreadLocal.ThreadLocalMap.Entry e) {
  ThreadLocal.ThreadLocalMap.Entry[] tab = table;
  int len = tab.length;
  while (e != null) {
    ThreadLocal<?> k = e.get();
    if (k == key)
      return e;
    if (k == null)
      expungeStaleEntry(i);
    else
      i = nextIndex(i, len);
    e = tab[i];
  }
  return null;
}

private T setInitialValue() {
  T value = initialValue();
  Thread t = Thread.currentThread();
  ThreadLocal.ThreadLocalMap map = getMap(t);
  if (map != null)
    map.set(this, value);
  else
    createMap(t, value);
  return value;
}

参考资料

  • 万字图文深度解析ThreadLocal
  • Java 并发编程深度解析与实战(Mic)

你可能感兴趣的:(Java,java)