如何实现线程池之间的数据透传 ?

如何实现线程池之间的数据透传 ?

  • 引言
  • transmittable-thread-local
    • 概览
    • capture
      • 如何 capture
      • 如何保存捕获的数据
    • save 和 replay
    • restore
  • 小结


引言

当我们涉及到数据的全链路透传场景时,通常会将数据存储在线程的本地缓存中,如: 用户认证信息透传,链路追踪信息透传时;但是这里可能面临着数据在两个没有血缘关系的兄弟线程间透传的问题,这通常涉及到两个不同线程池之间数据的透传问题,如下图所示:

在这里插入图片描述
为了解决上面这个问题,最简单的思路就是手动在各个线程池的切换处添加捕获和回放逻辑,如下所示:

public class TTLMain {
    private static final Executor TTL_TEST_THREAD_POOL = Executors.newFixedThreadPool(1);

    public static void main(String[] args) {
        ThreadLocal<Integer> userId = new ThreadLocal<>();
        userId.set(1);
        // 1. 捕获当前线程的上下文信息
        Integer captured = userId.get();
        TTL_TEST_THREAD_POOL.execute(() -> {
            userId.set(2);
            // 2. 保存当前线程的上下文信息
            Integer backup = userId.get();
            // 3. 重放捕获的目标线程的上下文信息
            userId.set(captured);
            System.out.println("重放上下文后: 用户ID=" + userId.get());
            // 4. 恢复原先的线程上下文信息
            userId.set(backup);
            System.out.println("恢复上下文后: 用户ID="+userId.get());
        });
    }
}

其实不难看出整个处理过程分为四个阶段:

  1. capture : 捕获当前线程上下文信息
  2. save : 保存目标线程上下文信息
  3. replay : 重放当前线程的上下文信息到目标线程中
  4. restore : 恢复目标线程原先的上下文信息

整个过程属于一个模版流程,因此我们可以想办法把上面这段逻辑单独抽取固定下来,而非在各个切换处进行手动编码操作,因此这里引出了我今天想要介绍的现成工具类: transmittable-thread-local 。


transmittable-thread-local

transmittable-thread-local 是阿里开源的一个线程池间数据透传工具类,它的实现思路其实就是上面我讲的四个阶段,下面我们先来看看transmittable-thread-local具体是如何使用的吧:

public class TTLMain {
    private static ExecutorService TTL_TEST_THREAD_POOL = Executors.newFixedThreadPool(1);

    public static void main(String[] args) {
        demo1();
        demo2();
        demo3();
    }

    private static void demo1() {
        // 1. 修饰Runnable
        TransmittableThreadLocal<Integer> context = new TransmittableThreadLocal<>();
        context.set(1);

        TTL_TEST_THREAD_POOL.execute(TtlRunnable.get(() -> {
            System.out.println("修饰Runnable: " + context.get());
        }));
    }

    @SneakyThrows
    private static void demo2() {
        // 1. 修饰Callable
        TransmittableThreadLocal<Integer> context = new TransmittableThreadLocal<>();
        context.set(1);

        Future<Integer> future = TTL_TEST_THREAD_POOL.submit(TtlCallable.get(context::get));
        System.out.println("修饰Callable: "+future.get());
    }

    @SneakyThrows
    private static void demo3() {
        // 1. 修饰线程池
        TTL_TEST_THREAD_POOL = TtlExecutors.getTtlExecutorService(TTL_TEST_THREAD_POOL);
        TransmittableThreadLocal<Integer> context = new TransmittableThreadLocal<>();
        context.set(1);

        Future<Integer> future = TTL_TEST_THREAD_POOL.submit(context::get);
        System.out.println("修饰线程池: "+future.get());
    }
}

使用上比较简单,核心还是将capture,save,replay ,restore 四个阶段的逻辑以模版流程的形式安排到了TtlRunnable和TtlCallable中,下面我们就来看看transmittable-thread-local具体是如何实现的,以及我们能从中学到什么样设计技巧。


概览

如何实现线程池之间的数据透传 ?_第1张图片
TransmittableThreadLocal实现了InheritableThreadLocal,其可以确保数据能够在父子线程间进行透传,透传逻辑体现在Thread的构造函数中;而TransmittableThreadLocal要做的事情就是解决数据在不同线程池之间进行数据透传的问题,该问题解决思路就是本篇开头提到的思路,下面我将分四个阶段,依次来看看TransmittableThreadLocal是如何实现的。


capture

捕获阶段我们需要捕获当前线程使用到的所有TransmittableThreadLocal实例的数据,这一点如何做到 ? 以及我们用什么样的数据结构来保持捕获到的数据呢 ?

便于行文方便,下面会将TransmittableThreadLocal简写为TTL

如何 capture

如果当前线程本身没有向某个TTL实例中设置任何数据,那么其实没有必要捕获该实例内部的数据,因此这里只会在初次调用TTL的set方法时,才会向TTL内部的全局ThreadLocal注册表进行注册:

    // WeakHashMap 的key是被弱引用对象引用着的,并且value值允许为null,因此可以作为set集合使用
    // 这里就是当做Set集合使用的,因为这里我们只需要知道当前线程使用到的TTL有哪些
    private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<>();
                }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<>(parentValue);
                }
            };
    
    
    @Override
    public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && value == null) {
            // may set null to remove value
            remove();
        } else {
            super.set(value);
            // 向全局ThreadLocal注册表进行注册
            addThisToHolder();
        }
    }

    private void addThisToHolder() {
        // 防止重复注册
        if (!holder.get().containsKey(this)) {
            // 进行注册,也就是添加到Set集合中
            holder.get().put((TransmittableThreadLocal<Object>) this, null);
        }
    }

holder的职责是负责记录当前线程使用到了哪些TTL,因此相对于TTL来说,Holder本身需要是全局静态的,同时又因为需要记录<线程,List< TTL >> 的映射关系,所以这里就有两种思路:

private static final Map<Thread,Set<TTL>> holder
or
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder

这两种思路的前者就不多说了,很直接的思路,后者则是将当前线程使用到的TTL保存到了当前线程本地空间中,这样就避免了holder集合多线程情况下争用问题的发生:
如何实现线程池之间的数据透传 ?_第2张图片
这里比较有趣的一点在于为什么holder保存当前线程使用到的TTL时,需要使用WeakHashMap这样一个弱引用Map呢 ?这一点和ThreadLocalMap中的Entry弱引用实现一致,那么这两者之间是否存在使用场景上的联系呢?

如何实现线程池之间的数据透传 ?_第3张图片
ThreadLocalMap使用场景下,Table中的key类型为ThreadLocal,val类型为我们通过ThreadLocal设置到当前线程本地空间中的值,如果ThreadLocal对象的引用变量和创建都位于某个方法内部,那么该方法执行完毕后,ThreadLocal理应被回收,如下所示:

    private static void threadLocalGCTest(){
        ThreadLocal<Integer> tl = new ThreadLocal<>();
        tl.set(1);
    }

按理来说,如果tl对象实例占比大小不大,在经过逃逸分析后,会优先进行栈上分配,那么当栈帧被弹出时,该对象理应被直接回收掉,那么这里实际上并不会,因为什么呢?

因为当前线程的table中存在entry的key引用着当前tl对象 :

如何实现线程池之间的数据透传 ?_第4张图片
但是此时应用程序本身已经失去了对tl对象实例的引用,按照道理来说tl是需要被回收掉的,如果不回收,那就等于发生了内存泄漏,因此这里Entry本身就必须采用弱引用实现,这样才能在GC扫描到当前对象时,将当前tl对象实例进行回收。

对象只存在弱引用,说明对象目前只被弱引用对象实例所引用,软引用和虚引用含义也算如此,这一点弄清楚很重要。

从下图也能看出,但是val并没有被回收掉,严格来说也算是内存泄漏,只有等到当前线程的ThreadLocalMap后面get和set过程中,进行探测式清理和启发式清理时,才会被回收掉 :
如何实现线程池之间的数据透传 ?_第5张图片
这里还有一点需要注意,如果把上述案例改为如下示例,此时ThreadLocal并没有被当前线程所使用到,因此也就不会主动注册到当前线程内部的ThreadLocalMap中去,也就不存在ThreadLocalMap中的key对当前ThreadLocal实例的引用关系了:

    private static void threadLocalGCTest(){
        ThreadLocal<Integer> tl = new ThreadLocal<>();
    }

因此也就无需考虑回收问题了。


那为什么TTL要采用WeakHashMap来保存当前线程使用到的TTL实例呢?

如何实现线程池之间的数据透传 ?_第6张图片
这里原因其实是一致的,如果TTL对象实例丢失了应用程序的强引用关联,那么必须确保TTL能够被回收掉,具体场景还是如下所示:

如何实现线程池之间的数据透传 ?_第7张图片
此时TTL就不止当前线程ThreadLocalMap中key的弱引用了,还多了一个全局注册表的引用,所以必须将该全局注册表对象也设置为弱引用实现。


如何保存捕获的数据

第一个问题搞清楚了,下面来看第二个问题: 我们应该使用什么样的数据结构来保存被捕获的数据呢 ?

这个问题我们需要回到TtlRunnable的实现中来,在TtlRunnable的构造函数中执行了第一阶段的捕获任务:

    private TtlRunnable(Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
                                                // 执行一阶段捕获任务
        this.capturedRef = new AtomicReference<>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

capture是Transmitter类的静态构造方法,从类名不难猜测出,TTL使用该类来保存被捕获的数据,下面来看看它的capture方法实现:

        public static Object capture() {
            final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = newHashMap(transmitteeSet.size());
            // 1. transmitteeSet 集合是什么呢 ? 为什么这里不直接从Holder集合取出当前线程使用到的所有TTL呢 ?
            for (Transmittee<Object, Object> transmittee : transmitteeSet) {
                 // 2. transmittee2Value 为什么要保存这样的映射关系 ? transmittee.capture() 该方法捕获了什么数据呢 ?
                 transmittee2Value.put(transmittee, transmittee.capture());
                 ...            
            }
            // 3. 这里返回的一定就是被捕获的数据了,那具体又是如何保存的呢?
            return new Snapshot(transmittee2Value);
        }

在阅读完上面这段代码后,相信大家应该都存在以上三个疑惑,那么下面我们就来一一探索一下吧。

首先,Transmittee类本身负责完成捕获,回放和恢复三件事情,如下图所示:

如何实现线程池之间的数据透传 ?_第8张图片

在Transmittee类初始化时,会向transmitteeSet集合中注册两个Transmittee对象实例,

        private static final Set<Transmittee<Object, Object>> transmitteeSet = new CopyOnWriteArraySet<>();

        static {
            registerTransmittee(ttlTransmittee);
            registerTransmittee(threadLocalTransmittee);
        }

ttlTransmittee 负责完成上图的捕获和回放过程,而因为只有TTL具备Transmittable的能力,所以为了让那些普通的ThreadLocal也能享受到Transmittable的能力,就有了threadLocalTransmittee。

关于threadLocalTransmittee这块不是重点,大家可以自行查看其实现,比较容易理解,所以就直接跳过了。

下面我们来看一下ttlTransmittee类的capture方法是如何从Holder中获取到当前线程所有的TTL,然后进行保存的:

        private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
                new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
                        // 1. 负责保存capture结果的集合
                        final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = newHashMap(holder.get().size());
                        // 2. 遍历Holder集合中保存的TTL,将其保存到capture集合中
                        // 这里的映射关系为: 
                        for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                            // 获取当前线程在当前TTL中保存的数据
                            ttl2Value.put(threadLocal, threadLocal.copyValue());
                        }
                        return ttl2Value;
                    }
                    
                    ...
     }                 

capture 捕获得到的结果映射集合如下图所示:
如何实现线程池之间的数据透传 ?_第9张图片
当依次处理完所有Transmittee后,当前线程本时刻上下文快照数据会被保存到Snapshot对象中,然后返回给TtlRunnable对象保存:

        private static class Snapshot {
            final HashMap<Transmittee<Object, Object>, Object> transmittee2Value;

            public Snapshot(HashMap<Transmittee<Object, Object>, Object> transmittee2Value) {
                this.transmittee2Value = transmittee2Value;
            }
        }

如何实现线程池之间的数据透传 ?_第10张图片


save 和 replay

重放阶段我们需要将已经捕获到的之前线程的上下文快照重放到当前线程上下文中,重放前我们需要保存当前线程的上下文快照,以便执行完当前runnable任务后,进行恢复:

save和replay两阶段在TLL实现中是紧密相连的,因此TTL中把这两个阶段合二为一,统称为了replay,同时capture,replay,restore三个阶段也缩称为CRR。

    @Override
    public void run() {
        // 1. 获取已经捕获到的之前线程的上下文快照
        final Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // 2. 将之前线程的上下文快照重放到当前线程上下文中,同时返回当前线程上下文快照
        final Object backup = replay(captured);
        try {
            // 3. 执行目标任务
            runnable.run();
        } finally {
            // 4. 利用backup快照,恢复当前线程之前的上下文环境
            restore(backup);
        }
    }

本节我们重点关注replay方法的实现,该方法分为两个阶段:

  1. 保存当前线程上下文快照
  2. 应用之前线程的上下文快照
        public static Object replay(@NonNull Object captured) {
            // 1. 获取之前线程的快照数据
            final Snapshot capturedSnapshot = (Snapshot) captured;
            // 2. 该集合用于保存当前线程的快照数据
            final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = newHashMap(capturedSnapshot.transmittee2Value.size());
            // 3. 遍历capturedSnapshot
            for (Map.Entry<Transmittee<Object, Object>, Object> entry : capturedSnapshot.transmittee2Value.entrySet()) {
                // 4. 获取transmittee和其对应的快照数据
                Transmittee<Object, Object> transmittee = entry.getKey();
                Object transmitteeCaptured = entry.getValue();
                // 5. 调用transmittee的replay方法进行快照重放,同时返回当前线程的快照,然后保存到transmittee2Value中
                transmittee2Value.put(transmittee, transmittee.replay(transmitteeCaptured));
                ...
            }
            // 6. 返回当前线程的上下文快照
            return new Snapshot(transmittee2Value);
        }

transmittee类的replay是快照保存和重放逻辑实现的关键点,下面我们一起来看看:

public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    // 1. 保存当前线程快照数据
    final HashMap<TransmittableThreadLocal<Object>, Object> backup = newHashMap(holder.get().size());
    // 2. 遍历当前线程Holder集合中每个TTL
    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        // 3. 依次获取每个TTL
        TransmittableThreadLocal<Object> threadLocal = iterator.next();
        // 4. 保存当前线程的上下文快照
        backup.put(threadLocal, threadLocal.get());

        // 5. 这里是要用之前线程上下文数据覆盖掉当前线程整个上下文数据,所以这里要分为讨论
        // 当前线程使用到之前线程没用到的ttl,那么直接清空ttl中的数据
        // 当前线程使用到了之前线程用到的ttl,那么直接覆盖,覆盖逻辑在循环下面 
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // 6. 当前线程使用到了之前线程用到的ttl,那么使用captured进行覆盖
    setTtlValuesTo(captured);

    // 7. 目前runnable或者callable任务执行前,回调ttl对应的接口
    doExecuteCallback(true);
    
    // 8. 返回当前线程的上下文快照数据
    return backup;
}

如何实现线程池之间的数据透传 ?_第11张图片

        private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
            // 遍历captured集合中所有ttl
            for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
                // 取出ttl
                TransmittableThreadLocal<Object> threadLocal = entry.getKey();
                // 把ttl对应的快照值重新设置回ttl中,此时就相当于设置到了当前线程本地空间中
                threadLocal.set(entry.getValue());
            }
        }

整个save 和 replay的过程比较简单,我们下面进入restore环节。


restore

当执行完目标任务后,就需要将当前线程之前的上下文状态进行恢复了,整个过程其实和调用函数类似,由于通用寄存器只存在一套,所以调用过程中就需要把通用寄存器当前状态压入函数栈帧中保存,待函数返回时,再从栈帧中弹出恢复先前运行状态。

这里由于线程本地空间只有一套,所以也需要在任务执行完毕后,恢复原本的上下文环境:

    @Override
    public void run() {
        // 1. 获取已经捕获到的之前线程的上下文快照
        final Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // 2. 将之前线程的上下文快照重放到当前线程上下文中,同时返回当前线程上下文快照
        final Object backup = replay(captured);
        try {
            // 3. 执行目标任务
            runnable.run();
        } finally {
            // 4. 利用backup快照,恢复当前线程之前的上下文环境
            restore(backup);
        }
    }

利用backup快照进行恢复的过程其实很简单,下面我们快速来过一遍:

        public static void restore(@NonNull Object backup) {
            // 1. 遍历backup快照中所有Transmittee
            for (Map.Entry<Transmittee<Object, Object>, Object> entry : ((Snapshot) backup).transmittee2Value.entrySet()) {
                // 2. 获取Transmittee对应的HashMap,里面保存着 
                Transmittee<Object, Object> transmittee = entry.getKey();
                Object transmitteeBackup = entry.getValue();
                // 3. 调用Transmittee的restore方法完成恢复过程
                transmittee.restore(transmitteeBackup);
                ...
            }
        }

transmittee类的restore是上下文恢复的关键点,下面我们一起来看看:

                    @Override
                    public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
                        // 1. runnable方法调用后,执行TTL对应的回调
                        doExecuteCallback(false);
                        // 2. 遍历当前Holder集合中所有TTL
                        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                            // 3. 获取到当前遍历的TTL
                            TransmittableThreadLocal<Object> threadLocal = iterator.next();

                            // 4. 将threadLocal中不存于backup的threadLocal都进行清空
                            if (!backup.containsKey(threadLocal)) {
                                iterator.remove();
                                threadLocal.superRemove();
                            }
                        }

                        // 5. 将backup中的TTL依次进行恢复,该方法上面介绍过,这里不再多说
                        setTtlValuesTo(backup);
                    }

如何实现线程池之间的数据透传 ?_第12张图片


小结

transmittable-thread-local 本身的设计思路不难理解,本文也只是针对TTL的核心流程源码进行了讲解,如果想进一步学习,可以自行拉取TTL项目源码进行学习。

TTL还提供了一种基于Java Agent的无侵入方案实现,感兴趣的小伙伴可以去 github 项目主页了解一波。

本文只是笔者个人观点,如果不正确的地方欢迎在评论区留言指出。

你可能感兴趣的:(#,技术杂谈,java)