【深入理解 ThreadLocal】

深入理解 ThreadLocal

  • 介绍
  • 源码分析
    • ThreadLocal 类图
    • set(T value)
    • T get()
    • 内存泄露
    • hash冲突解决
  • 总结

介绍

官方介绍:
此类提供线程局部变量。这些变量与它们的正常对应变量不同,因为每个访问一个变量的线程(通过其get或set方法)都有自己的、独立初始化的变量副本。ThreadLocal实例通常是希望将状态与线程关联的类中的私有静态字段(例如,用户ID或事务ID)。
通俗理解:
ThreadLocal 是用来做线程变量隔离的。对应公共变量,当出现在多线程情况下,容易出现数据修改混乱的情况,通常的做法是加锁(Synchronize和Lock)处理,但这样的做法会降低并发性能。还有一种做法是采用ThreadLocal。它是通过空间来换时间

源码分析

ThreadLocal 类图

ThreadLocal 的类图结构比较简单,它内部维护了一个ThreadLocalMap类,可以把ThreadLocalMap理解为一个map容器。于此同时,Thread对象里面也维护了一个ThreadLocal.ThreadLocalMap属性。

【深入理解 ThreadLocal】_第1张图片

set(T value)

首先通过当前线程拿到ThreadLocalMap对象(Thread 里面有ThreadLocal.ThreadLocalMap属性,ThreadLocalMap 可以理解为一个Map对象)。然后 用当前ThreadLocal 对象为Key,T Value 为value。存入map中。

  public void set(T value) {
        //拿到当前线程
        Thread t = Thread.currentThread();
        //通过当前线程拿到ThreadLocalMap 对象
        ThreadLocalMap map = getMap(t);
        if (map != null)
            //将值放入map中。key 为当前的ThreadLocal对象
            map.set(this, value);
        else
           //如果通过 当前线程拿到的ThreadLocalMap对象为空,
           //则构造一个新的ThreadLocalMap,然后将value值放入到map中
            createMap(t, value);
    }
  ThreadLocalMap getMap(Thread t) {
        //getMap 拿到的就是ThreadLocalMap对象
        return t.threadLocals;
    }
static class ThreadLocalMap {
        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         */
        private int size = 0;

        /**
         * The next size value at which to resize.
         */
        private int threshold; // Default to 0

        /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * Increment i modulo len.
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * Decrement i modulo len.
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }
        .....
    }
    
 void createMap(Thread t, T firstValue) {
        //构造一个新的ThreadLocalMap对象。最终是将ThreadLocalMap对象赋值给了Thread的属性。
        // 所以我们取值的时候,都是通过当前对象来拿到ThreadLocalMap对象,
        //然后通过key 来获取value
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
           //设置数组大小为16
            table = new Entry[INITIAL_CAPACITY];
            //通过key 的hash值 然后和数组大小减1进行与操作获得下标
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            //构造一个entry对象赋值给数组
            table[i] = new Entry(firstKey, firstValue);
            //容量加一
            size = 1;
            //设置阈值 为初始值的2/3.
            setThreshold(INITIAL_CAPACITY);
        }

T get()

如果前面的set(T value) 方法看明白的话,那get () 方法也就是很简单了。通过当前线程,拿到它的属性ThreadLocal.ThreadLocalMap对象。然后以当前对象ThreadLocal 为key。来从map中获取value。

 public T get() {
        Thread t = Thread.currentThread();
        //通过当前线程来获取ThreadLocalMap对象
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //从map 中获取value.   key 为当前对象 ThreadLocal
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //如果通过当前对象获取的ThreadLocalMap对象为空。则构造一个新的map,并返回null.
        return setInitialValue();
    }
  private T setInitialValue() {
         //返回初始的null值
        T value = initialValue();
        Thread t = Thread.currentThread();
        //构造一个新的ThreadLocalMap对象
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        //返回null 值
        return value;
    }
 protected T initialValue() {
        return null;
    }

内存泄露

我们来看一段ThreadLocal的代码例子。

private static final ThreadLocal<UserInfo> userInfoThreadLocal = new ThreadLocal<>();

public Response handleRequest(UserInfo userInfo) {
  Response response = new Response();
  try {
    // 1.用户信息set到线程局部变量中
    userInfoThreadLocal.set(userInfo);
    doHandle();
  } finally {
    // 3.使用完移除掉
    userInfoThreadLocal.remove();
  }
  return response;
}
//业务逻辑处理
private void doHandle () {
  // 2.实际用的时候取出来
  UserInfo userInfo = userInfoThreadLocal.get();
  //查询用户资产
  queryUserAsset(userInfo);
}

面试题: threadLocal 会出现内存溢出问题?
要分析这个问题,我们可以看下ThreadLocal 内存中的引用关系
【深入理解 ThreadLocal】_第2张图片
首先ThreadLocal对象ThreadLocal引用 引用着,这是一个强引用。同时,entry 的key 也被ThreadLocal对象引用。此时的引用是个弱引用。

   static class ThreadLocalMap {

		//Entry 继承于 WeakReference,是个弱引用
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
			//Entry 的key 是ThreadLocal对象,正是 Entry 继承的弱引用对象。
			// 即entry 的key 也是一个弱引用对象。
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

所以正常情况下,当threadLocal 所在的作用域结束了,工作被清理了。ThreadLocal对象就会被回收(虽然ThreadLocal对象被key引用,但是弱引用,依然会被gc回收)。然后key 就指向了null。
但是我们在使用线程池的情况下,线程使用完毕,不会被回收。所以 从Thread引用–> Thread对象 --> ThreadLocalMap–>Entry 这条引用线一直存在。而且是强引用。这样 Entry中的 value对象就不能被回收了?
此时我们就可以手动回收value对象了。调用 ThreadLocal 的remove方法。

 public void remove() {
 		//获取当前线程的ThreadLocalMap对象
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
         //调用ThreadLocalMap对象的remove方法
             m.remove(this);
     }
  private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            //计算数组的下标
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    //进行数组的元素清理
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

其实ThreadLocal 里面的get/set 方法都会删除map 中key 为null的元素。这样可以尽最大努力避免内存泄露。

get 方法会调用 getEntryAfterMiss方法

     private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                  //如果key 是null. 删除该元素
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

set 方法会调用 set -->replaceStaleEntry–>cleanSomeSlots 方法

      private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }
				//key 为null.做删除操作
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

hash冲突解决

ThreadLocalMap 的冲突策略和HashMap不一样。HashMap 冲突后,是采用链表方式解决。ThreadLocalMap 如果发生冲突,则将元素放入下一个数组下标。如果超过了阈值,则进行扩容。

总结

ThreadLocal的逻辑不是特别复杂。但面试 问的特别多。希望本文能帮助到大家理解这块的逻辑。特别是涉及到内存泄露这块。

你可能感兴趣的:(多线程编程,java,ThreadLocal,并发,弱引用)