遇到线程池InheritableThreadLocal就废了,该怎么办?

王二北原创,转载请标明出处:来自王二北

一、从项目中遇到的问题说起:

最近项目有一个需求,如下图所示:
上游有A/B/C三个服务,通过Dubbo调用中间的Proxy服务,Proxy最终Dubbo Rpc调用不同的Target服务。
A/B/C三个服务中,共有N中业务场景,Target服务中共提供了50+业务接口,现在需要统计上游ABC中每种业务调用Target接口情况。

遇到线程池InheritableThreadLocal就废了,该怎么办?_第1张图片
image

方案如下:

1、为了减少对接口的修改,上游A/B/C在调用Proxy服务时,通过Dubbo上下文传递业务字段buzSource,Proxy通过上下文获取buzSource内容。

2、然后在Proxy添加一个group=Consumer的DubboConsumerFilter,在这个DubboFilter中将buzSource和调用Target接口的对应关系收集起来,通过Mq发给收集系统。

3、我们都知道,Dubbo上下文其实就是将数据存放到一个ThreadLocal中,客户端掉用服务端时,就是将上下文数据存放到请求数据的头信息attachment,服务端接收到请求数据后,再将数据从请求头中取出,存放到当前的Dubbo RpcContext的上下文中。

在一次Dubbo调用结束后,其上下文信息就会被内置的ConsumerConextFilter清空(如下图所示)。

遇到线程池InheritableThreadLocal就废了,该怎么办?_第2张图片
image

上游服务(A/B/C)的一次业务调用,对于Proxy服务来说,可能会涉及到调用下游Target的多个RPC接口。

因此buzSource字段传递到Target服务中后不能继续存放到dubbo上下文中,因为第一次rpc调用这个字段就被清了,后面的RPC调用就取不到了。
因此,这个buzSource字段,如果要在Proxy中的一个处理逻辑中一直存在,那自然就需要放在ThreadLocal中了。
在Proxy服务中添加了一个GROUP=Provider的DubboProviderFilter,用于请求到来后,将buzSource从RpcContext取出并放到一个新的ThreadLocal中,在请求处理完后,清理调用当前线程对应的buzSource信息。
整体结构图如下所示:

遇到线程池InheritableThreadLocal就废了,该怎么办?_第3张图片
image

遇到的问题:

在Proxy中,有些业务是多线程并发调用的,也就说当上游服务A的请求到达Proxy后,请求后被拆分成多个子请求,放到线程池中去调用下游的Target服务。这个时候,线程池中的线程就无法获取业务主线程上下文中的buzSource了。

了解ThreadLocal的童鞋都知道,ThreadLocal的实现原理,就是因为Thread类中有一个ThreadLocalMap类型的属性threadLocals, key是ThreadLocal, value是你要存放的上下文的值。也就是说线程上下文是和线程一一绑定的,自然其他线程就无法获取了。


image

不想侵入业务代码,在每个业务调用的地方修改传入buzSource,那样太low了,怎么办,突然想到了InheritableThreadLocal,这个类是ThreadLocal的子类,作用是父线程存放到InheritableThreadLocal中的数据,子线程也可以拿到。简单做了个小例子:

private static final InheritableThreadLocal local = new InheritableThreadLocal();
public void test(){
    local.set("111111");
    new Thread(){
        public void run(){
            System.out.println(local.get());
        }
    }.start();
}
子线程输出结果为:111111

例子运行没有问题,于是兴冲冲的将原来存放buzSource的ThreadLocal改为了InheritableThreadLocal,然后在DubboConsumerFilter中取获取buzSource,然而,仍然获取不到!遇到线程池InheritableThreadLocal就废了,为什么?要弄明白原因,先来看看InheritableThreadLocal的实现原理.

二、InheritableThreadLocal实现原理

为什么线程池的线程无法获取业务线程存放到InheritableThreadLocal的数据?这就得从InheritableThreadLocal的原理说起了。

InheritableThreadLocal 继承自ThreadLocal,重写了其中crateMap方法和getMap方法。重写这两个方法的目的是使得所有线程通过InheritableThreadLocal设置的上下文信息,都保存在其对应的inheritableThreadLocals属性中。这一点和ThreadLocal不同,ThreadLocal是保存在Thread的threadLocals属性中。

下面是Thread类汇中threadLocals属性和inheritableThreadLocals两个属性,当调用ThreadLocal类型的上下文对象设置参数时,设置的就是其threadLocals属性对应的Map的kv值,当调用InheritableThreadLocal类型的上下文对象设置参数时,就是设置其inheritableThreadLocals属性的kv值:

 /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

下面是InheritableThreadLocal重写的crateMap方法和getMap方法,正是通过这两个方法,改变了ThreadLocal中要设置和获取Thread中哪个属性的方法。

public class InheritableThreadLocal extends ThreadLocal {
    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child
     * thread is created.  This method is called from within the parent
     * thread before the child is started.
     * 

* This method merely returns its input argument, and should be overridden * if a different behavior is desired. * * @param parentValue the parent thread's value * @return the child thread's initial value */ protected T childValue(T parentValue) { return parentValue; } /** * Get the map associated with a ThreadLocal. * * @param t the current thread */ ThreadLocalMap getMap(Thread t) { return t.inheritableThreadLocals; } /** * Create the map associated with a ThreadLocal. * * @param t the current thread * @param firstValue value for the initial entry of the table. */ void createMap(Thread t, T firstValue) { t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue); } }

那么子线程是如何能获取到父线程保存到InheritableThreadLocal类型上下文中数据的呢?
原来是在创建Thread对象时,会判断父线程中inheritableThreadLocals是否不为空,如果不为空,则会将父线程中inheritableThreadLocals中的数据复制到自己的inheritableThreadLocals中。这样就实现了父线程和子线程的上下文传递。

public Thread(ThreadGroup group, Runnable target, String name,
              long stackSize) {
    init(group, target, name, stackSize);
}

private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize) {
    init(g, target, name, stackSize, null, true);
}

private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    this.name = name;
    ............
    ............
    // 此处会初始化Thread对应的inheritThreadLocals
    // 如果父线程的inheritThreadLocals不为空,则会复制父线程的inheritThreadLocals
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        // 根据父线程的inheritableThreadLocals对应的parentMap创建子线程的inheritableThreadLocals
        this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    /* Stash the specified stack size in case the VM cares */
    this.stackSize = stackSize;

    /* Set thread ID */
    tid = nextThreadID();
}
// 根据父线程的inheritableThreadLocals对应的parentMap创建子线程的inheritableThreadLocals 
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}


private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];
    // 将父Thread中的inheritThreadLocals(ThreadLocalMap)对应的的所有值都复制给子类的inheritThreadLocals
    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal key = (ThreadLocal) e.get();
            if (key != null) {
                // 注意此处的childValue()方法对应InheritableThreadLocal中重写的childValue方法
                // 在前面InheritableThreadLocal中,是直接返回了传入的值
                // 也就是直接将这个值的引用(引用类型时是对象的引用,普通类型或者String时,传的是值本身)传给了子线程的Entry
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}

 
 

通过上面对InheritableThreadLocal的简单分析,我们可以知道,在创建Thread时,才会将父线程中的inheritableThreadLocals复制给新创建Thread的inheritableThreadLocals。

但是在线程池中,业务线程只是将任务对象(实现了Runnable或者Callable的对象)加入到任务队列中,并不是去创建线程池中的线程,因此线程池中线程也就获取不到业务线程中的上下文信息。

那么在线程池的场景下就没有方法解决了吗,只能去修改业务代码,在提交任务对象时,手工传入buzSource吗?观察了一下,项目中,前人写的代码,传个traceId啥的,都是这么修改业务代码传入traceId之类参数,这样很不优雅,如果以后有其他类似的字段,还是需要大量的进行修改,有没有什么办法可以解决这个问题?

三、解决遇到线程池InheritableThreadLocal就废了的问题

其实仔细想想,子线程之所以能获得父线程放到InheritableThreadLocal的数据,是因为在创建子线程时,复制了父线程的inheritableThreadLocals属性,触发复制的时机是创建子线程的时候。

在线程池场景下,是提交任务,既然要提交任务,那么就要创建任务,那么能否在创建任务的时候,做做文章呢?

下面是我的实现:

1、定义一个InheritableTask抽象类,这个类实现了Runaable接口,并定义了一个runTask抽象方法,当开发者需要面对线程池获取InheritableThreadLocal值的场景时,提交的任务对象,只需要继承InheritableTask类,实现runTask方法即可。

2、在创建任务类时,也就是在InheritableTask构造函数中,通过反射获,获取到提交任务的业务线程的inheritableThreadLocals属性,然后复制一份,暂存到当前task的inheritableThreadLocalsObj属性中。

3、线程池线程在执行该任务时,其实就是去调用其run()方法,在执行run方法时,先将暂存的inheritableThreadLocalsObj属性,赋值给当前执行任务的线程,这样这个线程就可以得到提交任务的那个业务线程的inheritableThreadLocals属性值了。然后再去执行runTask(),就是真正的业务逻辑。最后,finally清理掉执行当前业务的线程的inheritableThreadLocals属性。

/**
 * @author 王二北
 * @description
 * @date 2019/8/21
 */
public abstract class InheritableTask implements Runnable {
    private Object inheritableThreadLocalsObj;
    public InheritableTask(){
       try{
           // 获取业务线程的中的inheritableThreadLocals属性值
           Thread currentThread = Thread.currentThread();
           Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
           inheritableThreadLocalsField.setAccessible(true);
           // 得到当前线程中的inheritableThreadLocals属性值
           Object threadLocalMapObj = inheritableThreadLocalsField.get(currentThread);
           if(threadLocalMapObj != null){
               // 调用ThreadLocal中的createInheritedMap方法,重新复制一个新的inheritableThreadLocals值
              Class threadLocalMapClazz =   inheritableThreadLocalsField.getType();
               Method method =  ThreadLocal.class.getDeclaredMethod("createInheritedMap",threadLocalMapClazz);
               method.setAccessible(true);
              // 创建一个新的ThreadLocalMap类型的inheritableThreadLocals
              Object newThreadLocalMap = method.invoke(ThreadLocal.class,threadLocalMapObj);
              // 将这个值暂存下来
              inheritableThreadLocalsObj = newThreadLocalMap;
           }
       }catch (Exception e){
           throw new IllegalStateException(e);
       }
    }

    /**
     * 搞个代理方法,这个方法中处理业务逻辑
     */
    public abstract void runTask();

    @Override
    public void run() {
        // 此处得到的是当前处理该业务的线程,也就是线程池中的线程
        Thread currentThread = Thread.currentThread();
        Field field = null;
        try {
            field  = Thread.class.getDeclaredField("inheritableThreadLocals");
            field.setAccessible(true);
            // 将暂存的值,赋值给currentThread
            if (inheritableThreadLocalsObj != null && field != null) {
                field.set(currentThread, inheritableThreadLocalsObj);
                inheritableThreadLocalsObj = null;
            }
            // 执行任务
            runTask();
        }catch (Exception e){
            throw new IllegalStateException(e);
        }finally {
            // 最后将线程中的InheritableThreadLocals设置为null
           try{
               field.set(currentThread,null);
           }catch (Exception e){
               throw new IllegalStateException(e);
           }
        }
    }
}

下面做个例子测试一下:

public class TestInheritableThreadLocal {
    private static InheritableThreadLocal local = new InheritableThreadLocal();
    private static ExecutorService es = Executors.newFixedThreadPool(5);
    public static void main(String[] args)throws Exception{
        for(int i =0;i<2;i++){
            final int ab = i;
            new Thread(){
                public void run(){
                    local.set("task____"+ab);
                    for(int i = 0;i<3;i++){
                        final  int a = i;
                        es.execute(new InheritableTask() {
                            @Override
                            public void runTask() {
                              System.out.println(Thread.currentThread().getName()+"_"+ ab+"get_"+ a +":" + local.get());
                            }
                        });
                    }
                }
            }.start();
        }
    }
)
运行结果,每个线程设置的值,都能被正确的获取到:
pool-1-thread-3_0get_1:task____0
pool-1-thread-4_1get_1:task____1
pool-1-thread-5_0get_2:task____0
pool-1-thread-1_1get_0:task____1
pool-1-thread-2_0get_0:task____0
pool-1-thread-3_1get_2:task____1

这样,就解决了在线程池场景下InheritableThreadLocal无效的问题。
然而,就这么完了吗?不,别忘了,反射是比较耗性能的。
一般优化反射性能的方式有两种,一种是使用缓存,一种是使用性能较高的反射工具,比如RefelectASM之类的。
我再使用的时候回发现RefelectAsm并不是特别好用,因为其不能反射获取private的字段,并且在获取inheritableThreadLocals字段时总是不成功,这里只展示一下使用缓存的实现:

/**
 * @author 王二北
 * @description
 * @date 2019/8/21
 */
public abstract class InheritableTaskWithCache implements Runnable {
    private Object obj;
    private static volatile Field inheritableThreadLocalsField;
    private static volatile Class threadLocalMapClazz;
    private static volatile Method createInheritedMapMethod;
    private static final Object accessLock = new Object();


    public InheritableTaskWithCache(){
       try{
           Thread currentThread = Thread.currentThread();
           Field field = getInheritableThreadLocalsField();
           // 得到当前线程中的inheritableThreadLocals熟悉值ThreadLocalMap, key是各种inheritableThreadLocal,value是值
           Object threadLocalMapObj = field.get(currentThread);
           if(threadLocalMapObj != null){
              Class threadLocalMapClazz = getThreadLocalMapClazz();
              Method method =  getCreateInheritedMapMethod(threadLocalMapClazz);
              // 创建一个新的ThreadLocalMap
              Object newThreadLocalMap = method.invoke(ThreadLocal.class,threadLocalMapObj);
              obj = newThreadLocalMap;
           }
       }catch (Exception e){
           throw new IllegalStateException(e);
       }
    }

    private Class getThreadLocalMapClazz(){
        if(inheritableThreadLocalsField == null){
            return null;
        }else {
            if(threadLocalMapClazz == null){
                synchronized (accessLock){
                    if(threadLocalMapClazz == null){
                        Class clazz = inheritableThreadLocalsField.getType();
                        threadLocalMapClazz = clazz;
                    }
                }
            }
        }
        return threadLocalMapClazz;
    }

    private Field getInheritableThreadLocalsField(){
        if(inheritableThreadLocalsField == null){
            synchronized (accessLock){
                if(inheritableThreadLocalsField == null){
                    try {
                        Field field = Thread.class.getDeclaredField("inheritableThreadLocals");
                        field.setAccessible(true);
                        inheritableThreadLocalsField = field;
                    }catch (Exception e){
                        throw new IllegalStateException(e);
                    }
                }
            }
        }
        return inheritableThreadLocalsField;
    }

    private Method getCreateInheritedMapMethod(Class threadLocalMapClazz){
        if(threadLocalMapClazz != null && createInheritedMapMethod == null){
            synchronized (accessLock){
                if(createInheritedMapMethod == null){
                    try {
                        Method method =  ThreadLocal.class.getDeclaredMethod("createInheritedMap",threadLocalMapClazz);
                        method.setAccessible(true);
                        createInheritedMapMethod = method;
                    }catch (Exception e){
                        throw new IllegalStateException(e);
                    }
                }
            }
        }
        return createInheritedMapMethod;
    }

    public abstract void runTask();

    @Override
    public void run() {
        boolean isSet = false;
        Thread currentThread = Thread.currentThread();
        Field field = getInheritableThreadLocalsField();
        try {
            if (obj != null && field != null) {
                field.set(currentThread, obj);
                obj = null;
                isSet = true;
            }
            // 执行任务
            runTask();
        }catch (Exception e){
            throw new IllegalStateException(e);
        }finally {
            // 最后将线程中的InheritableThreadLocals设置为null
           try{
               field.set(currentThread,null);
           }catch (Exception e){
               throw new IllegalStateException(e);
           }
        }
    }
}

下面对比一下使用缓存和不使用缓存的性能:

我使用的笔记本电脑,是I7,8核16G, 测试时,由于已经开了几个idea和一堆Chrome网页,cpu使用率已经达到60%左右。

首先是不使用缓存直接反射的Task实现,共执行了6次,每次都创建了3000w个InheritableTask对象,每次执行耗时如下:

reflect cost:2905ms
reflect cost:2165ms
reflect cost:2424ms
reflect cost:2756ms
reflect cost:2242ms
reflect cost:2487ms

然后是使用缓存反射字段的task实现,共执行了6次,每次都创建了3000w个InheritableTask对象,每次执行耗时如下:

cache cost:82ms
cache cost:70ms
cache cost:58ms
cache cost:94ms
cache cost:71ms
cache cost:60ms

可以发现使用cache的情况下,性能提高了30~40倍。总的来说,在使用缓存的情况下,性能还是不错的。

综上,通过实现一个抽象的InheritableTask解决了线程池场景下InheritableThreadLocal“失效”的问题。

总结:
1、InheritableThreadLocal在线程池下是无效的,原因是只有在创建Thread时才会去复制父线程存放在InheritableThreadLocal中的值,而线程池场景下,主业务线程仅仅是将提交任务到任务队列中。

2、如果需要解决这个问题,可以自定义一个RunTask类,使用反射加代理的方式来实现业务主线程存放在InheritableThreadLocal中值的间接复制。

你可能感兴趣的:(遇到线程池InheritableThreadLocal就废了,该怎么办?)