ThradLocal原理解析及SpringSecurity无法在子线程中获取上下文信息解决

ThreadLocal使用及其原理解析

一、前言


项目中使用到了SpringSecurity框架作为安全验证,但是却发现一个问题,即当在子线程中获取SecurityContextHolder中存储的对象时会报空指针异常,后来发现原来SpringSecurity默认是将对象信息存储在ThreadLocal类型的变量中,因此在子线程调用便会出现空指针异常,异常代码如下:

    public void insertFill(MetaObject metaObject) {
        log.info("创建时的自动插入策略生效");
        this.strictInsertFill(metaObject, SystemConfig.createTime, Date.class, new Date());
        if(loginUser == null){
            try {
                //此处会出现空指针异常
                loginUser = (AdminLoginUser) SecurityContextHolder.getContext().getAuthentication().getPrincipal();
            }catch (Exception e){
                log.info("用户尚未登录,自动插入策略无需生效");
            }
        }
        if (loginUser != null && loginUser.getAdminUser() != null){
            this.strictInsertFill(metaObject, SystemConfig.createBy, String.class, loginUser.getAdminUser().getId());
        }
    }

SpringSecurity使用ThreadLocal存储上下文信息代码如下:

	public static SecurityContext getContext() {
		return strategy.getContext();
	}
	
 	//默认采用就是ThreadLocal进行存储
	private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();

	public SecurityContext getContext() {
		SecurityContext ctx = contextHolder.get();

		if (ctx == null) {
			ctx = createEmptyContext();
			contextHolder.set(ctx);
		}

		return ctx;
	}

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

因此,为了解决这一问题,打算研究一下ThreadLocal使用及其原理

二、作用


1、基本概念

ThreadLocal可以解释成线程的局部变量,也就是说一个ThreadLocal的变量只有当前自身线程可以访问,别的线程都访问不了,那么自然就避免了线程竞争。因此,ThreadLocal提供了一种与众不同的线程安全方式,它不是在发生线程冲突时想办法解决冲突,而是彻底的避免了冲突的发生

2、基本使用

创建一个ThreadLocal对象:

    public static void main(String[] args) {
        ThreadLocal<String> localVar = new ThreadLocal<>();
        localVar.set("hello world!!!"); //设置值为hello world!!!
        String val = localVar.get(); //在当前线程取出对应的值
        System.out.println(val); //输出结果为hello world!!!
        
        //新开线程取ThreadLocal变量的值
        Thread thread = new Thread(() -> {
            System.out.println(localVar.get());//输出结果为null
        });
        thread.start();
    }

由于ThreadLocal里设置的值,只有当前线程自己看得见,这意味着你不可能通过其他线程为它初始化值。为了弥补这一点,ThreadLocal提供了一个withInitial()方法统一初始化所有线程的ThreadLocal的值,此时的值对所有线程可见:

        ThreadLocal<String> localVar = ThreadLocal.withInitial(() -> "hello world!!!");
        localVar.set("hello world!!!");
        String val = localVar.get();
        System.out.println(val); //输出结果为hello world!!!
        Thread thread = new Thread(() -> {
            System.out.println(localVar.get()); //输出结果也为hello world!!!
        });
        thread.start();

三、ThradLocal的实现原理


1、get()方法

ThreadLocal变量只在单个线程内可见,那它是如何做到的呢?我们先从最基本的get()方法说起:

    public T get() {
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取当前线程绑定的ThreadLocalMap变量
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //ThreadLocalMap的key就是当前ThreadLocal对象实例
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                //从map里取出值
                T result = (T)e.value;
                return result;
            }
        }
        //当map为空或者对应的key不存在时进行初始化
        return setInitialValue();
    }

    private T setInitialValue() {
        //直接赋值null
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            //当map存在时,将null与对应的ThreadLocal对象实例进行绑定
            map.set(this, value);
        else
            //当map不存在时,为当前线程创建对应的ThreadLocalMap变量
            createMap(t, value);
        return value;
    }

    protected T initialValue() {
        return null;
    }

    void createMap(Thread t, T firstValue) {
        //为当前线程创建对应的ThreadLocalMap变量--threadLocals,所有的ThreadLocal对象实例及其绑定的对象存在threadLocals中
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        //ThreadLocalMap底层与HashMap有相似之处,也是通过数组的形式进行存储,每个数组中存的是内部类Entry
        table = new Entry[INITIAL_CAPACITY];
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
    }

因此ThreadLocal数据隔离的实现是因为ThreadLocal类操作的是Thread的成员变量threadLocals。每个线程Thread都有自己的threadLocals,从而互相不影响。

threadLocals这个成员变量的本质又是ThreadLocalMap类,它是ThreadLocal的内部类,下面我们研究一下这个内部类的数据结构:

static class ThreadLocalMap {
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
    //初始化容量
    private static final int INITIAL_CAPACITY = 16;
    //散列表
    private Entry[] table;
    //有效数量
    private int size = 0;
    //负载因子
    private int threshold;
    
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }
    //构造器
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY];
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
    }
}

ThreadLocalMap底层与HashMap有相似之处,也是通过数组的形式进行存储,每个数组中存的是内部类Entry,而且Entry继承了WeakReference类,即ThreadLocalMapkey是弱引用,value是强引用

key设为弱引用的好处是,如果这个变量不再被其他对象使用时,可以自动回收这个ThreadLocal对象,避免可能的内存泄露(注意:Entry中的value依然是强引用)

因此其数据结构如下所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PFv3JSFJ-1665133197543)(E:\JavaResources\学习笔记\ThreadLocal数据结构.png)]

2、set()方法

ThreadLocalMapjava.util.HashMap两者对于Hash冲突的解决方式是不同的。

对于HashMap而言,其使用的是链表法来处理冲突:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jvJgtIVe-1665133197543)(E:\JavaResources\学习笔记\链表法.png)]

而对于ThreadLocalMap而言,其使用的是简单的线性探测法,如果发生了元素冲突,那么就使用下一个槽位存放:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vMh7jXoj-1665133197544)(E:\JavaResources\学习笔记\image-20221005163649149.png)]

因此,ThreadLocalset()过程如下:

        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            //根据hash值找到key对应的索引位置
            int i = key.threadLocalHashCode & (len-1);
            
            //判断是否发生hash冲突,如果冲突,则一直往下找,直到找到未冲突的索引位置
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                //如果key一样,将值进行替换
                if (k == key) {
                    e.value = value;
                    return;
                }
                //如果key为null,表示原来的key已经被回收了,那么进行清理,具体清理过程见下文
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //解决了hash冲突,将entry放到对应的位置
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

3、ThreadLocal中的内存泄漏问题

虽然ThreadLocalMap中key是弱引用,当不存在外部强引用的时候,就会自动被回收,但是Entry中的value依然是强引用。这个value的引用链条如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FFzh8swb-1665133197544)(E:\JavaResources\学习笔记\引用链条.png)]

可以看到,只有当Thread被回收时,这个value才有被回收的机会,否则,只要线程不退出,value总是会存在一个强引用。但是,要求每个Thread都会退出,是一个极其苛刻的要求,对于线程池来说,大部分线程会一直存在在系统的整个生命周期内,那样的话,就会造成value对象出现泄漏的可能。处理的方法是,在ThreadLocalMap进行set(),get(),remove()的时候,都会进行清理

get()方法举例,从上文中可以看到get()方法中会调用 ThreadLocalMap.Entry e = map.getEntry(this);,如下所示:

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                //检测到key为null,进行清理
                return getEntryAfterMiss(key, i, e);
        }

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    //清理key为null的元素
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }


        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    //将其value置为null,从而让gc进行自动回收
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

真正用来回收value的是expungeStaleEntry()方法,在remove()set()方法中,都会直接或者间接调用到这个方法进行value的清理。

从这里可以看到,ThreadLocal为了避免内存泄露,也算是花了一番大心思。不仅使用了弱引用维护key,还会在每个操作上检查key是否被回收,进而再回收value

但是从中也可以看到,ThreadLocal并不能100%保证不发生内存泄漏。

比如,很不幸的,你的get()方法总是访问固定几个一直存在的ThreadLocal,那么清理动作就不会执行,如果你没有机会调用set()remove(),那么这个内存泄漏依然会发生。

因此,一个良好的习惯依然是:当你不需要这个ThreadLocal变量时,主动调用remove(),这样对整个系统是有好处的

四、InheritableThreadLocal的作用


1、基本使用

让我们回到本文开头,本文的主要目的除了探讨ThreadLocal外,还需要解决SpringSecurity中无法在子线程中获取在主线程中存储的对象,而SpringSecurity中是将对象存在了ThradlLocal中。即实际开发过程中,我们可能会遇到这么一种场景。主线程开了一个子线程,但是我们希望在子线程中可以访问主线程中的ThreadLocal对象,也就是说有些数据需要进行父子线程间的传递

此时便可以使用InheritableThreadLocal,顾名思义,这就是一个支持线程间父子继承的ThreadLocal

    public static void main(String[] args) {
        ThreadLocal<String> localVar = new InheritableThreadLocal<>();
        localVar.set("hello world!!!"); 
        String val = localVar.get();
        System.out.println(val);//输出为hello world!!!
        Thread thread = new Thread(() -> {
            System.out.println(localVar.get());//输出也为hello world!!!
        });
        thread.start();
    }

从而每个线程都可以访问到从父进程传递过来的一个数据。虽然InheritableThreadLocal看起来挺方便的,但是依然要注意以下几点:

  • 变量的传递是发生在线程创建的时候,如果不是新建线程,而是用了线程池里的线程,就不灵了

  • 如果采用的是线程池执行异步任务,并且还想在SpringSecurity中传播上下文,推荐的方法是改写线程池的配置

        @Bean(name = "asyncPoolTaskExecutor")
        public ThreadPoolTaskExecutor executor(){
            ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
            taskExecutor.setCorePoolSize(8);
            taskExecutor.setMaxPoolSize(16);
            taskExecutor.setQueueCapacity(50);
            taskExecutor.setKeepAliveSeconds(200);
            taskExecutor.setThreadNamePrefix("async-");
            taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.DiscardOldestPolicy());
            //主要是添加DelegatingSecurityContextRunnable,然后进行初始化initialize
            taskExecutor.setTaskDecorator(DelegatingSecurityContextRunnable::new);
            taskExecutor.initialize();
            return taskExecutor;
        }
    
    //参考链接:https://stackoverflow.com/questions/3467918/how-to-set-up-spring-security-securitycontextholder-strategy
    
  • 变量的赋值就是从主线程的map复制到子线程,它们的value是同一个对象,如果这个对象本身不是线程安全的,那么就会有线程安全问题

2、实现原理

InheritableThreadLocal 实际上是 ThreadLocal 的子类,我们来看下 InheritableThreadLocal 的定义:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    
    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
   
}

从源码中可以看出,InheritableThreadLocalThreadLocal主要的不同在于,getMap 方法的返回值变成了 inheritableThreadLocals 对象,但是在createMap 方法中,构建出来的 inheritableThreadLocals 还依然是 ThreadLocalMap 的对象。ThreadLocal 相比,主要是保存数据的对象从 threadLocals 变为 inheritableThreadLocals

这样的变化,对于前面的我们所说的 ThreadLocal 中的 get/set 并不影响,也就是 ThreadLocal 的特性依然不变

于是继续深挖,发现当使用了InheritableThreadLocal之后,最大的变化发生在Threadinit方法中,init方法会在new Thread()时自动调用,源码如下:

private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
    //inheritThreadLocals在new Thread()时默认是true
    ...
    //找到父线程
    Thread parent = currentThread();
    ...
    ...
    //在创建子线程的时候,如果父线程存在 inheritableThreadLocals 变量且不为空,就调用 ThreadLocal.createInheritedMap 方法为子线程的 inheritableThreadLocals 变量赋值
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    ...
    ...
}

//将父线程的 inheritableThreadLocals 变量值赋值给子线程的 inheritableThreadLocals 变量。因此,在子线程中就可以访问到父线程 ThreadLocal 中的数据了
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
}

因此通过在Threadinit方法中将父线程的inheritableThreadLocals赋值给子线程的inheritableThreadLocals 变量。因此,在子线程中就可以访问到父线程ThreadLocal中的数据了

五、SpringSecurity中解决子线程无法访问主线程中所存储的数据

实际上 SpringSecurity中的SecurityContextHolder一共定义了三种存储策略进行上下文的存储,如下所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6KkjsUFq-1665133197544)(E:\JavaResources\学习笔记\SecurityContextHolder的存储策略.png)]

public class SecurityContextHolder {
 public static final String MODE_THREADLOCAL = "MODE_THREADLOCAL";
 public static final String MODE_INHERITABLETHREADLOCAL = "MODE_INHERITABLETHREADLOCAL";
 public static final String MODE_GLOBAL = "MODE_GLOBAL";
    ...
    ...
}

而第二种存储策略MODE_INHERITABLETHREADLOCAL就支持在子线程中获取当前登录用户信息,而MODE_INHERITABLETHREADLOCAL 的底层使用的就是 InheritableThreadLocal,因此只需自定义SpringSecurity的上下文存储策略,便可以在子线程中获取到当前存储信息

可以通过改写系统启动参数完成该项配置,在系统启动时加上如下参数:

-Dspring.security.strategy=MODE_INHERITABLETHREADLOCAL

你可能感兴趣的:(Java学习,java,spring,boot,架构)