ThreadLocal源码解析

1 ThreadLocal简单介绍

ThreadLocal是JDK包提供的,它提供了线程本地变量。通俗一点说,就是你创建了一个ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的一个本地副本。当多个线程操作这个变量时,实际操作的是自己本地内存里面的变量,避免了线程安全问题。

2 ThreadLocal实现原理

ThreadLocal源码解析_第1张图片
由上图可得,Thread类里有一个threadLocals和inheritableThreadLocals,他们都是ThreadLocalMap类型的变量。 是一个定制化的HashMap。默认情况下,这两个变量都是null,只有当前线程第一次调用ThreadLocal的get或者set方法的时候,才会创建他们。其实每个线程的本地变量不是存放在ThreadLocal实例里面,而是放在调用线程的threadLocals里面。通俗一点,就是ThreadLocal类型的本地变量存放在具体的线程内存空间里面。ThreadLocal只是一个工具,它通过set方法把value放入调用线程的threadLocals里面并存起来,通过get方法,再从当前线程的threadLocals变量里面将其拿出来使用。如果调用线程一直不终止,那么这个本地变量会一直存放在调用线程的threadLocals变量里面。当不需要使用本地变量的时候,可以通过调用ThreadLocal的remove方法,从当前调用线程的threadLocals里面删除该本地变量。

2.1 void set(T value)

public void set(T value) {
	// 获取当前线程
    Thread t = Thread.currentThread();
    // 将当前线程作为search key,查找当前线程对应的ThreadLocalMap变量
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
    	// 第一次调用,就创建当前线程对应的ThreadLocalMap 
        createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

从源码可以看出, getMap(Thread t)的作用是获取线程自己的变量threadLocals,threadLocals被绑定到了线程的成员变量上。
如果getMap(Thread t)的返回值不为空,则把value设置到threadLocals中,也就是把当前变量值放入到当前线程的内存变量threadLocals中。threadLocals是一个HashMap的结构(下文会解析ThreadLocalMap),其中key就是当前ThreadLocal的实例对象的引用,value是通过set方法传递的值。
如果getMap(Thread t)的返回值为空,就说明是第一次调用set方法,这是创建当前线程的threadLocals变量。这也说明了threadLocals默认为null,只有当前线程第一次调用ThreadLocal的get或者set方法的时候,才会创建他们。

2.2 T get()

public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的threadLocals变量
    ThreadLocalMap map = getMap(t);
    // 如果threadLocals不为null,则返回对应本地变量的值
    if (map != null) {
    	// 通过ThreadLocal实例的引用,找到对应的entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // threadLocals为空,初始化当前线程的threadLocals成员变量
    return setInitialValue();
}

 private T setInitialValue() {
 	 // 初始化化为null 
     T value = initialValue();
     Thread t = Thread.currentThread();
     ThreadLocalMap map = getMap(t);
     // 如果当前线程的threadLocals不为null
     if (map != null)
         map.set(this, value);
     else
     	// 如果当前线程的threadLocals为null
         createMap(t, value);
     return value;
 }

 /**
  * 该方法可重写
  */
 protected T initialValue() {
    return null;
}

没啥好说的,首先后去当前线程实例,如果当前线程的threadLocals变量不为空,则直接返回当前线程绑定的本地变量,如果当前线程的threadLocals变量为空,进行初始化,创建threadLocals变量,并且返回null。

2.3 void remove()

 public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread());
     if (m != null)
         m.remove(this);
 }

如果当前线程的threadLocals变量不为空,则删除当前线程中指定ThreadLocal实例的本地变量。

3 ThreadLocalMap源码解析(主要)

ThreadLocaLMap是ThreadLocal的内部静态类。

static class ThreadLocalMap {

	/**
	 * 自定义一个Entry类,继承弱引用,保存ThreadLocal和Value之间的对应关系
	 * 用弱引用,是为了解决线程与ThreadLocal之间的强绑定关系
	 * 弊端:如果线程没有被回收,则GC便一直无法回收这部分内容
	 */
    static class Entry extends WeakReference<ThreadLocal<?>> {
       /** The value associated with this ThreadLocal. */
       Object value;

       Entry(ThreadLocal<?> k, Object v) {
           super(k);
           value = v;
       }
   }

   /**
     * 根据长度计算扩容的阀值
     * Set the resize threshold to maintain at worst a 2/3 load factor.
     */
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

    /**
     * Increment i modulo len.
     * 获取下一个索引, 超出长度,则返回0
     */
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

    /**
     * Decrement i modulo len.
     * 获取上一个索引,超出下限,则返回len-1
     * 结合nextIndex,就会发现Entry数组是个环形结构
     */
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
    }

	/**
	 * 构造函数
	 *
	 */
	ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
	       // 初始化table的大小为16
	       table = new Entry[INITIAL_CAPACITY];
	       // 通过hashcode & (长度-1)的位运算,确定键值对的位置
	       int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
	       // 创建一个新节点保存在table当中
	       table[i] = new Entry(firstKey, firstValue);
	       // 设置table中entry的数量
	       size = 1;
	       // 设置扩容阀值
	       setThreshold(INITIAL_CAPACITY);
	}
	
	/**
	 *  获取ThreadLocal的索引位置,通过下标索引获取内容
	 */
    private Entry getEntry(ThreadLocal<?> key) {
       // 通过hashcode确定下标
       int i = key.threadLocalHashCode & (table.length - 1);
       // 如果找到则直接返回
       Entry e = table[i];
       if (e != null && e.get() == key)
           return e;
       else
           // 找不到的话接着从i位置开始向后遍历
           // 基于线性探测法,是有可能在i之后的位置找到的
           return getEntryAfterMiss(key, i, e);
   }
	
   /**
    * 
    */ 	
   private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
      Entry[] tab = table;
      int len = tab.length;
	
	  // 循环向后遍历		
      while (e != null) {
      	  // 获取节点对应的k	
          ThreadLocal<?> k = e.get();
          // 相等则返回
          if (k == key)
              return e;
          // 如果为null,触发一次连续段清理
          if (k == null)
              expungeStaleEntry(i);
          else
          	  // 获取下一个下标接着进行判断	
              i = nextIndex(i, len);
          e = tab[i];
      }
      return null;
  }
  private void set(ThreadLocal<?> key, Object value) {
        // 新开一个引用指向table
        Entry[] tab = table;
        // 获取table的长度
        int len = tab.length;
        // 获取entry table当中的下标
        int i = key.threadLocalHashCode & (len-1);
		
		  /**
     	   * 从该下标开始循环遍历
           * 1、如遇相同key,则直接替换value
           * 2、如果该key已经被回收失效,则替换该失效的key
           */
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();

            if (k == key) {
                e.value = value;
                return;
            }
			// 如果 k 为null,则替换当前失效的k所在Entry节点
            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }

		 // 找到空的位置,创建Entry对象并插入
        tab[i] = new Entry(key, value);
        // table内元素size自增
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }	

	/**
	 * Remove the entry for key.
	 * 将ThreadLocal对象对应的Entry节点从table当中删除
 	 */
    private void remove(ThreadLocal<?> key) {
       Entry[] tab = table;
       int len = tab.length;
       int i = key.threadLocalHashCode & (len-1);
       for (Entry e = tab[i];
            e != null;
            e = tab[i = nextIndex(i, len)]) {
           if (e.get() == key) {
           	   // 将引用设置null,方便GC
               e.clear();
               // 从该位置开始进行一次连续段清理
               expungeStaleEntry(i);
               return;
           }
       }
   }

}

着重说一下expungeStaleEntry方法,为什么要进行连续段清理呢?
因为这些变量对于线程来说,是隔离的本地变量,并且使用的是弱引用,有可能在GC的时候就被回收了。而value不是弱引用,在key为null的时候,需要及时的被清理。

  • 如果有很多Entry节点已经被回收了,但是在table数组中还留着位置,这时候不清理就会浪费资源
  • 在清理节点的同时,可以将后续非空的Entry节点重新计算下标进行排放,这样子在get的时候就能快速定位资源,加快效率
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    // 获取长度
    int len = tab.length;

    // expunge entry at staleSlot  
    // 先将传过来的数组下标置为null
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    
    // table的size-1
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
     // 遍历删除指定节点所有后续节点当中,ThreadLocal被回收的节点
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        // 获取entry当中的key
        ThreadLocal<?> k = e.get();
        // 如果ThreadLocal为null,则将value以及数组下标所在位置设置null,方便GC
        // 并且size-1
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
        	// 如果不为null。重新计算数组的下标
            int h = k.threadLocalHashCode & (len - 1);
             
            // 如果是当前位置则遍历下一个
            // 如果不是当前位置,则重新从i开始找到下一个为null的坐标进行赋值
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

4 ThreadLocal的弊端

ThreadLocal的主要问题会产生脏数据和内存泄露。这两个问题通常是在线程池的线程中使用ThreadLocal引发的,因为线程池有线程复用和内存常驻的两个特点。

4.1 脏数据

线程池会复用Thread对象,那么与Thread绑定的类的静态属性ThreadLoca变量也会被重用。如果没有显示的调用remove方法清理与线程相关的ThreadLocal信息,那么倘若下一个线程不调用set方法设置初始值,就可能get到重用的线程信息,包括ThreadLocal所关联的线程对象的value值。

4.2 内存泄露

在源码注释中提示使用static关键字来修饰ThreadLocal。在此场景下,寄希望于ThreadLocal对象是去引用后,触发弱引用机制来回收Entry的Value就不现实了。如果不进行remove操作,这个线程执行完成后,ThreadLocal对象持有的Values对象是不会被释放的。

4.3 情景复现

笔者就不一一道来了,仅提供主要的部分的代码。用的Spring Boot创建的工程,配置了一个拦截器,一个Controller以及ThreadLocal对象的持有类。

/**
 * @Author: Heiky
 * @Date: 2020/4/16 13:49
 * @Description: 存放用户id 模拟ThreadLocal内存泄露
 */
public class UserContextHolder {

    private static ThreadLocal<Integer> contextHolder = ThreadLocal.withInitial(() -> null);

    /**
     * 存放用户id
     *
     * @param key
     */
    public static void set(Integer key) {
        contextHolder.set(key);
    }

    /**
     * 获取用户id
     *
     * @return
     */
    public static Integer get() {
        return contextHolder.get();
    }

    /**
     * 重置threadLocal
     */
    public static void clear() {
        contextHolder.remove();
    }
}

/**
 * @Author: Heiky
 * @Date: 2020/4/16 11:09
 * @Description:
 */

@RequestMapping("/threadlocal")
@RestController
public class ThreadLocalTestController {


    @GetMapping("/wrong")
    public Map wrong(@RequestParam("userId") Integer userId) {
        //设置用户信息之前先查询一次ThreadLocal中的用户信息
        String before = Thread.currentThread().getName() + ":" + UserContextHolder.get();
        //设置用户信息到ThreadLocal
        UserContextHolder.set(userId);
        //设置用户信息之后再查询一次ThreadLocal中的用户信息
        String after = Thread.currentThread().getName() + ":" + UserContextHolder.get();
        //汇总输出两次查询结果
        Map result = new HashMap();
        result.put("before", before);
        result.put("after", after);
        return result;
    }

}

设置tomcat的最大线程数,设置为1。

server:
  port: 8080
  tomcat:
    uri-encoding: utf-8
    # 最大并发数
    max-threads: 1

第一次请求的结果,此时传递userId=7
ThreadLocal源码解析_第2张图片
第二次请求的结果,此时传递userId=7,此时就不对了,获取了上个请求留下信息,因为没有及时remove。
ThreadLocal源码解析_第3张图片
如果想解决这个问题,只需加个拦截器,将ThreadLocal里面的信息进行remove。

/**
 * @Author: Heiky
 * @Date: 2020/4/16 13:47
 * @Description:
 */
@Component
public class ThreadLocalTestInterceptor implements HandlerInterceptor {

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, 
    										Object handler, Exception ex) throws Exception {
        UserContextHolder.clear();
    }
}

/**
 * @Author: Heiky
 * @Date: 2020/3/23 19:56
 * @Description:
 */

@Configuration
public class WebMvcConf extends WebMvcConfigurationSupport {

    @Autowired
    private ThreadLocalTestInterceptor threadLocalTestInterceptor;

    @Override
    protected void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(threadLocalTestInterceptor).addPathPatterns("/threadlocal/wrong");
        super.addInterceptors(registry);
    }
}

第一次查询结果
ThreadLocal源码解析_第4张图片
第二次查询结果
ThreadLocal源码解析_第5张图片
综上,要及时的在代码里面把ThreadLocal里面的信息remove。

5 ThreadLocal的总结

ThreadLocal源码解析_第6张图片

  1. 一个Thread有且仅有1个ThreadLocalMap对象
  2. 1个Entry对象的Key弱引用指向1个ThreadLocal对象
  3. 1个ThreadLocalMap对象可以存储多个Entry对象
  4. 1个ThreadLocal对象可以被多个线程所共享
  5. ThreadLocal对象不持有Value,Value是由线程的Entry对象持有

你可能感兴趣的:(Java基础,ThreadLocal)