深入理解ThreadLocal

1 定义

ThreadLocal是存储线程局部变量的容器。

它为每一个使用该变量的线程都提供了一个变量值的副本,是Java中一种较为特殊的线程绑定机制。

每一个线程都可以独立地改变自己的副本,而不会和其它线程的副本发生冲突。

2 原理分析

Java中,Thread类代表线程。

查看Thread源码,如下:

public class Thread implements Runnable {
    ......

    /**
     * 与此线程相关的ThreadLocal值。
     * 这个map由ThreadLocal类维护。
     */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * 与此线程相关的那些从父线程继承而来的ThreadLocal值。
     * 这个map由InheritableThreadLocal类维护。
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

可以看出,在Thread中是通过ThreadLocal.ThreadLocalMap来发挥ThreadLocal的功能的。

下面来看一下ThreadLocal的工作原理。

ThreadLocal提供了set(T value)get()方法,用来存取线程局部变量。

2.1 ThreadLocalset(T value)方法

查看ThreadLocalset(T value)方法的源码:

    /**
     * 设置当前线程中线程局部变量的值
     */
    public void set(T value) {
        // 获取当前线程对象
        Thread t = Thread.currentThread();
        // 获取ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        // 如果map存在,就将设置的value存入map中;如果map不存在,就创建一个map并写入值
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

getMap(t)源码如下:

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

createMap(t, value)的源码如下:

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

可以看出,getMap是将当前线程对象t传入,然后获取当前线程对象tthreadLocals的引用。因为每个线程Thread都有自己的threadLocals,所以getMap(t)返回的ThreadLocalMap是每个线程自己的。

每个线程中都有一个独立的ThreadLocalMap副本,它所存储的值,只能被当前线程读取和修改。

ThreadLocal类通过操作每一个线程特有的ThreadLocalMap副本,从而实现了变量访问在不同线程中的隔离。因为每个线程的变量都是自己的,完全不会有并发错误。

ThreadLocalMap存储的键值对中的键是this对象指向的ThreadLocal对象,而值就是所设置的对象。

2.2 ThreadLocalget()方法

查看ThreadLocalget()方法源码:

    /**
     * 设置当前线程中线程局部变量的值
     */
    public T get() {
        // 获取当前线程对象
        Thread t = Thread.currentThread();
        // 获取ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        // 1.如果map存在
        if (map != null) {
            // 从map中获取键值对,键为当前ThreadLocal对象
            ThreadLocalMap.Entry e = map.getEntry(this);
            // 如果键值对存在
            if (e != null) {
                // 获取键值对中的值,并返回
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        // 2.如果map不存在,设置初始化值并返回它
        return setInitialValue();
    }

getMap(t)源码如下:

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

setInitialValue()源码如下:

    /**
     * 设置初始化值
     */
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

    /**
     * 初始化
     */
    protected T initialValue() {
        return null;
    }

可以看出,和set(T value)同理,也是通过ThreadLocalMap来获取线程的局部变量的。这样就能保证获取到的值都是每个线程自己的副本,线程之间不会相互影响。

2.3 总结

实际使用的时候,ThreadLocal变量作为类中的实例域,会被所有的线程共享。

但是,每个线程获取ThreadLocal对象之后,通过set(T value)方法设置值的时候,首先是获取线程自己的ThreadLocalMap对象,然后将设置的值存入ThreadLocalMap中,键为这个线程获取的ThreadLocal对象,值为设置的value

所以,对于同一个ThreadLocal来说,在每个线程中的ThreadLocalMap中的键都是同一个对象;每个线程中的ThreadLocalMap可以有多个键值对,那么不同的键对应的就是不同的ThreadLocal实例域对象。

get()方法的原理和set(T value)一样的,也是通过通过ThreadLocalMap来实现线程隔离的。

3 需要注意的点

3.1 ThreadLocalMap不是Map接口的实现

ThreadLocalMap不是Map接口的实现,内部使用的是Entry[] table来保存键值对的。

   /**
    * The table, resized as necessary.
    * table.length MUST always be a power of two.
    */
    private Entry[] table;

并且,通过hash算法来做散列:

    // 计算数组索引
    int i = key.threadLocalHashCode & (len-1);

    ......

    // 生成threadLocal对象的hash值
    private final int threadLocalHashCode = nextHashCode();

    ......
    
    /**
     * 返回下一个hash值
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

3.2 ThreadLocal使用不当引发的内存泄漏问题

ThreadLocal可能存在内存泄漏问题的根源在于:

ThreadLocal中的key是弱引用的。

源码:

    /**
     * Entry继承WeakReference,
     * map中的键(ThreadLocal对象)是弱引用。
     * 注意,null键(entry.get() == null)表示这个键不再被引用,因此这个entry可以从数组中移除。
     */
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

3.2.1 为什么使用弱引用

要理解为什么ThreadLocalMap中需要使用WeakReference作为key类型,那么首先需要理解WeakReference的意义。

WeakReferenceJava语言规范中为了区别直接的对象引用(程序中通过构造函数声明出来的对象引用)而定义的另外一种引用关系。WeakReference标志性的特点是:不会影响到被引用对象的GC回收行为(即,只要对象被除了WeakReference对象之外所有的对象解除引用后,该对象便可以被GC回收),只不过在被引用对象回收之后,通过WeakReference获得被引用对象时程序会返回null

理解了WeakReference之后,ThreadLocalMap使用它的目的也相对清晰了:

ThreadLocal实例可以被GC回收时(该实例没有任何强引用了),系统可以通过弱引用检测到该ThreadLocal对应的Entry是否已经过期(根据reference.get() == null来判断,如果为true则表示过期,程序内部称为stale slots)来做一些自动清除工作,否则如果不清除的话容易产生内存无法释放的问题——value对应的对象即使不再使用,但由于被ThreadLocalMap所引用导致无法被GC回收。

3.2.2 内存泄漏问题

转载:https://www.jianshu.com/p/dde92ec37bd1

下面的图展示了ThreadLocalThreadLocalMapEntry之间的关系:

深入理解ThreadLocal_第1张图片

上图中,实线代表强引用,虚线代表弱引用。

  • 如果ThreadLocal实例对象的外部强引用ThreadLocalRef被置为nullthreadLocalRef == null)的话,ThreadLocal实例对象就没有一条引用链路可达,很显然在GC的时候势必会被回收。

  • 因此这个ThreadLocal实例对应的Entry就存在keynull的情况,程序是无法通过一个keynull去访问到该Entryvalue

  • 如果当前线程未被销毁。那么,就存在这样一条引用链:currentThreadRef -> currentThread -> threadLocalMap -> entry -> valueRef -> valueMemory,导致在垃圾回收的时进行可达性分析的时候,value可达从而不会被回收掉,但是该value永远不能被访问到了,这样导致了 **内存泄漏 ** 的问题。

  • 当然,如果线程执行结束后,栈被销毁,那么threadLocalRefcurrentThreadRef就会断掉。因此ThreadLocalThreadLocalMapEntry都会被回收掉,对应的value也会被回收,不会出现**内存泄漏 **。

  • 可是,在实际使用中我们大多数情况都会用线程池去维护我们的线程,线程在使用完之后并不会被销毁,而是返回到线程池中,这时候很可能出现ThreadLocal内存泄漏的问题,需要我们多加关注。

3.2.3 已经做出了哪些改进?

实际上,为了解决ThreadLocal潜在的内存泄漏的问题,Josh Bloch and Doug Lea大师已经做了一些改进。

ThreadLocalsetget方法中都有相应的处理。

下文为了叙述,针对key == nullentry,源码注释为stale entry,直译为“不新鲜的entry”,这里我就称之为“脏entry”。

查看ThreadLocalMapset方法:

    private void set(ThreadLocal<?> key, Object value) {

        // We don't use a fast path as with get() because it is at
        // least as common to use set() to create new entries as
        // it is to replace existing ones, in which case, a fast
        // path would fail more often than not.

        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);

        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();

            if (k == key) {
                e.value = value;
                return;
            }

            // 脏entry
            if (k == null) {
                // 替换这个脏entry
                replaceStaleEntry(key, value, i);
                return;
            }
        }

        // 插入新的entry
        tab[i] = new Entry(key, value);
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }

在该方法中针对 脏entry做了这样的处理:

  1. 如果当前table[i]!= null的话,说明hash冲突,就需要向后环形查找,若在查找过程中遇到脏entry就通过replaceStaleEntry(key, value, i)进行处理;
  2. 如果当前table[i] == null的话,说明这是新的entry,可以直接插入,但是插入后会调用cleanSomeSlots(i, sz)方法检测并清除脏entry

具体的源码分析,参见https://www.jianshu.com/p/dde92ec37bd1

3.3 ThreadLocal最佳实践

因为在线程池中使用ThreadLocal的时候,很可能引发内存泄漏的问题,所以:

在确定不再使用ThreadLocal的时候,请调用remove()方法删除数据。

下面是remove的源码:

    /**
     * 移除当前ThreadLocal对应的线程局部变量
     */
    public void remove() {
        // 拿到当前线程的ThreadLocalMap
        ThreadLocalMap m = getMap(Thread.currentThread());
        // 如果map不是null,调用map的remove方法删除当前这个ThreadLocal对应的Entry
        if (m != null)
            m.remove(this);
    }


    /**
     * 移除key对应的Entry
     */
    private void remove(ThreadLocal<?> key) {
        // 根据key计算数组中的索引位置i
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
            
        for (Entry e = tab[i];
            e != null;
            e = tab[i = nextIndex(i, len)]) {
            // 查找到与key对应的Entry
            if (e.get() == key) {
                // 通过clear方法将key的引用置为null,这个entry就变成了一个“脏Entry”
                e.clear();
                // 通过
                expungeStaleEntry(i);
                return;
            }
        }
    }


    /**
     * 删除“脏Entry”
     */
    private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // 删除entry
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

4 应用场景

最常见的ThreadLocal使用场景为 用来解决 数据库连接、Session管理等。

例如:

    private static ThreadLocal<Connection> connectionHolder = new ThreadLocal<Connection>() {
        public Connection initialValue() {
            return DriverManager.getConnection(DB_URL);
        }
    };

    public static Connection getConnection() {
        return connectionHolder.get();
    }

下面这段代码摘自 https://www.iteye.com/topic/103804:

    private static final ThreadLocal<Session> threadSession = new ThreadLocal<>();

    public static Session getSession() throws InfrastructureException {
        Session s = (Session) threadSession.get();
        try {
            if (s == null) {
                s = getSessionFactory().openSession();
                threadSession.set(s);
            }
        } catch (HibernateException ex) {
            throw new InfrastructureException(ex);
        }
        return s;
    }

你可能感兴趣的:(JAVA多线程编程)