白话ThreadLocal原理

ThreadLocal作用

对于Android程序员来说,很多人都是在学习消息机制时候了解到ThreadLocal这个东西的。那它有什么作用呢?官方文档大致是这么描述的:

  • ThreadLocal提供了线程局部变量
  • 每个线程都拥有自己的变量副本,可以通过ThreadLocal的set或者get方法去设置或者获取当前线程的变量,变量的初始化也是线程独立的(需要实现initialValue方法)
  • 一般而言ThreadLocal实例在类中被private static修饰
  • 当线程活着并且ThreadLocal实例能够访问到时,每个线程都会持有一个到它的变量的引用
  • 当一个线程死亡后,所有ThreadLocal实例给它提供的变量都会被gc回收(除非有其它的引用指向这些变量)
    上述中“变量”是指ThreadLocal的get方法获取的值
白话ThreadLocal原理_第1张图片
继续往下看.jpg

简单例子

先来看一个简单的使用例子吧:

public class ThreadId {

    private static final AtomicInteger nextId = new AtomicInteger(0);

    private static final ThreadLocal threadId = new ThreadLocal() {
        @Override
        protected Integer initialValue() {
            return nextId.get();
        }
    };

    public static int get() {
        return threadId.get();
    }
}

这也是官方文档上的例子,非常简单,就是通过在不同线程调用ThredId.get()可以获取唯一的线程Id。如果在调用ThreadLocal的get方法之前没有主动调用过set方法设置值的话,就会返回initialValue方法的返回值,并把这个值存储为当前线程的变量。

ThreadLocal到底是用来解决什么问题,适用什么场景呢,例子是看懂了,但好像还是没什么体会?ThreadLocal既然是提供变量的,我们不妨把我们见过的变量类型拿出来,做个对比

局部变量、成员变量 、 ThreadLocal、静态变量

变量类型 作用域 生命周期 线程共享性 作用
局部变量 方法(代码块)内部,其他方法(代码块)不能访问 方法(代码块)开始到结束 只存在于每个线程的工作内存,不能在线程中共享 解决变量在方法(代码块)内部的代码行之间的共享
成员变量 实例内 和实例相同 可在线程间共享 解决变量在实例方法之间的共享,否则方法之间只能靠参数传递变量
静态变量 类内部 和类的生命周期相同 可在多个线程间共享 解决变量在多个实例之间的共享
ThreadLocal存储的变量 整个线程 一般而言与线程的生命周期相同 不再多线程间共享 解决变量在单个线程中的共享问题,线程中处处可访问

ThreadLocal存储的变量本质上间接算是Thread的成员变量,ThreadLocal只是提供了一种对开发者透明的可以为每个线程存储同一维度成员变量的方式。

共享 or 隔离

网上有很多人持有如下的看法:
ThreadLocal为解决多线程程序的并发问题提供了一种新思路或者ThreadLocal是为了解决多线程访问资源时的共享问题。
个人认为这些都是错误的,ThreadLocal保存的变量是线程隔离的,与资源共享没有任何关系,也没有解决什么并发问题,这一点看了ThreadLocal的原理就会更加清楚。就好比上面的例子,每个线程应该有一个线程Id,这并不是什么并发问题啊。

同时他们会拿ThreadLocal与sychronized做对比,我们要清楚它们根本不是为了解决同一类问题设计的。sychronized是在牵涉到共享变量时候,要做到线程间的同步,保证并发中的原子性与内存可见性,典型的特征是多个线程会访问相同的变量。而ThreadLocal根本不是解决线程同步问题的,它的场景是A线程保存的变量只有A线程需要访问,而其它的线程并不需要访问,其他线程也只访问自己保存的变量。

原理

我们来一个开放性的问题,假如现在要给每个线程增加一个线程Id,并且Java的Thread类你能随便修改,你要怎么操作?非常简单吧,代码大概是这样

public class Thread{
      private int id;
      
      public void setId(int id){
          this.id=id;
      }
}

那好,现在题目变了,我们现在还得为每个线程保存一个Looper对象,那怎么办呢?再加一个Looper的字段不就好了,显然这种做法肯定是不具有扩展性的。那我们用一个容器类不就好了,很自然地就会想到Map,像下面这样

public class Thread{

      private Map map;
    
     public Map getMap(){
         if(map==null)
            map=new HashMap<>();
         return map;
     }
   
}

然后我们在代码里就可以通过如下代码来给Thread设置“成员变量”了

   Thread.currentThread().getMap().put("id",id);
   Thread.currentThread().getMap().put("looper",looper);

然后可以在该线程执行的任意地方,这样访问:

  Looper looper=(Looper) Thread.currentThread().getMap().get("looper");

看上去还不错,但是还是有些问题:

  • 保存和获取变量都要用到字符换key
  • 因为map中要保存各种值,因此泛型只得用Object,这样获取时候就需要强制转换(可用泛型方法解)
  • 当该变量没有作用时候,此时线程还没有执行完,需要手动设置该变量为空,否则会造成内存泄漏

为了不通过字符串访问,同时省去强制转换,我们封装一个类,就叫ThreadLocal吧,伪代码如下:

  public class ThreadLocal {

    public void set(T value) {
        Thread t = Thread.currentThread();
         Map map = t.getMap();
        if (map != null)
           //以自己为键
            map.put(this, value);
        else
            createMap(t, value);
    }


    public T get() {
        Thread t = Thread.currentThread();
        Map,T> map = t.getMap();
        if (map != null) {
            T e = map.get(this);
            return e;
        }
        return setInitialValue();
    }
}

没错,以上基本上就是ThreadLocal的整体设计了,只是线程中存储数据的Map是特意实现的ThreadLocal.ThreadLocalMap。

ThreadLocal与线程的关系如下:


白话ThreadLocal原理_第2张图片
ThreadLocal与线程的关系.png

如上图如所示,ThredLocal本身并不存储变量,只是向每个线程的threadLocals中存储键值对。ThreadLocal横跨线程,提供一种类似切面的概念,这种切面是作用在线程上的。

我们对ThreadLocal已经有一个整体的认识了,接下来我们大致看一下源码

源码分析

TheadLocal

   public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

set方法通过Thread.currentThread方法获取当前线程,然后调用getMap方法获取线程的threadLocals字段,并往ThreadLocalMap中放入键值对,其中键为ThreadLocal实例自己。

 ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
   }

接着看get方法:

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

很清晰,其中值得注意的是最后一行的setInitialValue方法,这个方法在我们没有调用过set方法时候调用。

private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

setInitialValue方法会获取initialValue的返回值并把它放进当前线程的threadLocals中。默认情况下initialValue返回null,我们可以实现这个方法来对变量进行初始化,就像上面TheadId的例子一样。

remove方法,从当前线程的ThreadLocalMap中移除元素。

public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

TheadLocalMap

看ThreadLocalMap的代码我们主要是关注以下两个方面:

  1. 散列表的一般设计问题。包括散列函数,散列冲突问题解决,负载因子,再散列等。
  2. 内存泄漏的相关处理。一般而言ThreadLocal 引用使用private static修饰,但是假设某种情况下我们真的不再需要使用它了,手动把引用置空。上面我们知道TreadLocal本身作为键存储在TheadLocalMap中,而ThreadLocalMap又被Thread引用,那线程没结束的情况下ThreadLocal能被回收吗?

散列函数
先来理一下散列函数吧,我们在之后的代码中会看到ThreadLocalMap通过int i = key.threadLocalHashCode & (len-1);决定元素的位置,其中表大小len为2的幂,因此这里的&操作相当于取模。另外我们关注的是threadLocalHashCode的取值。

  private final int threadLocalHashCode = nextHashCode();
 private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
  private static AtomicInteger nextHashCode =
        new AtomicInteger();
   private static final int HASH_INCREMENT = 0x61c88647;

这里很有意思,每个ThreadLocal实例的threadLocalHashCode是在之前ThreadLocal实例的threadLocalHashCode上加 0x61c88647,为什么偏偏要加这么个数呢?
这个魔数的选取与斐波那契散列有关以及黄金分割法有关,具体不是很清楚。它的作用是这样产生的值与2的幂取模后能在散列表中均匀分布,即便扩容也是如此。看下面一段代码:

  public class MagicHashCode {
      //ThreadLocal中定义的魔数
     private static final int HASH_INCREMENT = 0x61c88647;
     
     public static void main(String[] args) {
         hashCode(16);//初始化16
         hashCode(32);//2倍扩容
         hashCode(64);
     }
 
    private static void hashCode(int length){         
        int hashCode = 0; 
         for(int i=0;i

输出结果为:

7 14 5 12 3 10 1 8 15 6 13 4 11 2 9 0   //容量为16时
7 14 21 28 3 10 17 24 31 6 13 20 27 2 9 16 23 30 5 12 19 26 1 8 15 22 29 4 11 18 25 0  //容量为32时
7 14 21 28 35 42 49 56 63 6 13 20 27 34 41 48 55 62 5 12 19 26 33 40 47 54 61 4 11 18 25 32 39 46 53 60 3 10 17 24 31 38 45 52 59 2 9 16 23 30 37 44 51 58 1 8 15 22 29 36 43 50 57 0  //容量为64时

因为ThreadLocalMap使用线性探测法解决冲突(下文会看到),均匀分布的好处在于发生了冲突也能很快找到空的slot,提高效率。

瞄一眼成员变量:

       /**
         * 初始容量,必须是2的幂。这样的话,方便把取模运算转化为与运算, 
         * 效率高
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 容纳Entry元素,长度必须是2的幂
         */
        private Entry[] table;

        /**
         * table中的元素个数.
         */
        private int size = 0;

        /**
         * table里的元素达到这个值就需要扩容了
         * 其实是有个装载因子的概念的
         */
        private int threshold; // Default to 0

构造函数:

  ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY];
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
  }

firstKey和firstValue就是Map存放的第一个键值对喽。其中firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)很关键,就是当容量为2的幂时候,这相当于一个取模操作。然后把Entry存储到数组的第i个位置,设置扩容的阈值。

private void setThreshold(int len) {
          threshold = len * 2 / 3;
 }

这说明当数组里的元素容量达到2/3时候就要扩容,也就是装载因子是2/3。
接下来我们来看下Entry

 static class Entry extends WeakReference> {
            Object value;
            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
        }

就这么点东西,这个Entry只是与HashMap不同,只是个普通的键值对,没有链表结构相关的东西。另外Entry只持有对键,也就是ThreadLocal的弱引用,那么我们上面的第二个问题算是有答案了。当没有其他强引用指向ThreadLocal的时候,它其实是会被回收的。但是这有引出了另外一个问题,那Entry呢?当键都为空的时候这个Entry也是没有什么作用啊,也应该被回收啊。不慌,我们接着往下看。

set方法:

 private void set(ThreadLocal key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
           //如果冲突的话,进入该循环,向后探测
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal k = e.get();
               //判断键是否相等,相等的话只要更新值就好了
                if (k == key) {
                    e.value = value;
                    return;
                }
         
                if (k == null) {
                   //该Entry对应的ThreadLocal已经被回收,执行replaceStaleEntry并返回
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            tab[i] = new Entry(key, value);
            int sz = ++size;
            //进行启发式清理,如果没有清理任何元素并且表的大小超过了阈值,需要扩容并重哈希
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

我们发现如果发生冲突的话,整体逻辑会一直调用nextIndex方法去探测下一个位置,直到找到没有元素的位置,逻辑上整个表是一个环形。下面是nextIndex的代码,就是加1而已。

private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
}

线性探测的过程中,有一种情况是需要清理对应Entry的,也就是Entry的key为null,我们上面讨论过这种情况下的Entry是无意义的。因此调用
replaceStaleEntry(key, value, i);在看replaceStaleEntry(key, value, i)我们先明确几个问题。采用线性探测发解决冲突,在插入过程中产生冲突的元素之前一定是没有空的slot的。这样在也确保在查找过程,查找到空的slot就可以停止啦。但是假如我们删除了一个元素,就会破坏这种情况,这时需要对表中删除的元素后面的元素进行再散列,以便填上空隙。

空slot:即该位置没有元素
无效slot:该位置有元素,但key为null

replaceStaleEntry除了将value放入合适的位置之外,还会在前后连个空的slot之间做一次清理expungeStaleEntry,清理掉无效slot。

private void replaceStaleEntry(ThreadLocal key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 向前扫描到一个空的slot为止,找到离这个空slot最近的无效slot,记录为slotToExpunge
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len)) {
        if (e.get() == null) {
            slotToExpunge = i;
        }
    }

    // 向后遍历table
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal k = e.get();

        // 找到了key,将其与无效slot交换
        if (k == key) {
            // 更新对应slot的value值
            e.value = value;
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
            //如果之前还没有探测到过其他无效的slot
            if (slotToExpunge == staleSlot) {
                slotToExpunge = i;
            }
            // 从slotToExpunge开始做一次连续段的清理,再做一次启发式清理
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 如果当前的slot已经无效,并且向前扫描过程中没有无效slot,则更新slotToExpunge为当前位置
        if (k == null && slotToExpunge == staleSlot) {
            slotToExpunge = i;
        }
    }

    // 如果key之前在table中不存在,则放在staleSlot位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 在探测过程中如果发现任何其他无效slot,连续段清理后做启发式清理
    if (slotToExpunge != staleSlot) {
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }
}

expungeStaleEntry主要是清除连续段之前无效的slot,然后对元素进行再散列。返回下一个空的slot位置。

 private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // 删除 staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                   //对元素进行再散列
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

启发式地清理:
i对应是非无效slot(slot为空或者有效)
n是用于控制控制扫描次数
正常情况下如果log n次扫描没有发现无效slot,函数就结束了。
但是如果发现了无效的slot,将n置为table的长度len,做一次连续段的清理,再从下一个空的slot开始继续扫描。
这个函数有两处地方会被调用,一处是插入的时候可能会被调用,另外个是在替换无效slot的时候可能会被调用, 区别是前者传入的n为实际元素个数,后者为table的总容量。

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        // i在任何情况下自己都不会是一个无效slot,所以从下一个开始判断
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 扩大扫描控制因子
            n = len;
            removed = true;
            // 清理一个连续段
            i = expungeStaleEntry(i);
        }
    } while ((n >>>= 1) != 0);
    return removed;
}

接着看set函数,如果循环过程中没有返回,找到合适的位置,插入元素,表的size增加1。这个时候会做一次启发式清理,如果启发式清理没有清理掉任何无效元素,判断清理前表的大小大于阈值threshold的话,正常就要进行扩容了,但是表中可能存在无效元素,先把它们清除掉,然后再判断。

private void rehash() {
    // 全量清理
    expungeStaleEntries();
    //因为做了一次清理,所以size可能会变小,这里的实现是调低阈值来判断是否需要扩容。 threshold默认为len*2/3,所以这里的threshold - threshold / 4相当于len/2。
    if (size >= threshold - threshold / 4) {
        resize();
    }
}

作用即清除所有无效slot

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null) {
            expungeStaleEntry(j);
        }
    }
}

保证table的容量len为2的幂,扩容时候要扩大2倍

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal k = e.get();
            if (k == null) {
                e.value = null; 
            } else {
                // 扩容后要重新放置元素
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null) {
                    h = nextIndex(h, newLen);
                }
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

get方法:

private Entry getEntry(ThreadLocal key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 对应的entry存在且key未被回收
    if (e != null && e.get() == key) {
        return e;
    } else {
        // 继续往后查找
        return getEntryAfterMiss(key, i, e);
    }
}

private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    // 不断向后探测直到遇到空entry
    while (e != null) {
        ThreadLocal k = e.get();
        // 找到
        if (k == key) {
            return e;
        }
        if (k == null) {
            // 该entry对应的ThreadLocal实例已经被回收,调用expungeStaleEntry来清理无效的entry
            expungeStaleEntry(i);
        } else {
            // 下一个位置
            i = nextIndex(i, len);
        }
        e = tab[i];
    }
    return null;
}

remove方法,比较简单,在table中找key,如果找到了断开弱引用,做一次连续段清理。

private void remove(ThreadLocal key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len - 1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            //断开弱引用
            e.clear();
            // 连续段清理
            expungeStaleEntry(i);
            return;
        }
    }
}

ThreadLocal与内存泄漏

从上文我们知道当调用ThreadLocalMap的set或者getEntry方法时候,有很大概率会去自动清除掉key为null的Entry,这样就可以断开value的强引用,使对象被回收。但是如果如果我们之后再也没有在该线程操作过任何ThreadLocal实例的set或者get方法,那么就只能等线程死亡才能回收无效value。因此当我们不需要用ThreadLocal的变量时候,显示调用ThreadLocal的remove方法是一种好的习惯。

小结

  • ThredLocal为每个线程保存一个自己的变量,但其实ThreadLocal本身并不存储变量,变量存储在线程自己的实例变量ThreadLocal.ThreadLocalMap threadLocals
  • ThreadLocal的设计并不是为了解决并发问题,而是解决一个变量在线程内部的共享问题,在线程内部处处可以访问
  • 因为每个线程都只会访问自己ThreadLocalMap 保存的变量,所以不存在线程安全问题

你可能感兴趣的:(白话ThreadLocal原理)