跨线程池共享的ThreadLocal

背景

在实际开发中,我们经常会用线程池处理大量任务,但是线程池的使用会让线程变量ThreadLocal无法访问,会很不爽.
举栗,当我们想提高性能,用线程池同时调用多个服务,又不想修改原本代码,实现无侵入的特性,就会很有用.

原理

我们希望使ThreadLocal线程变量跨线程共享,这就要打破jdk提供的访问限制.
ThreadLocal的线程隔离是通过在每个线程内部维护一个ThreadLocalMap的映射表,每次获取都是从当前线程或者父线程的map中(对于InheritableThreadLocal)取值,从而实现的线程间变量访问的隔离.

// ThreadLocal 的部分源码
// 获取线程的ThreadLocalMap 
ThreadLocalMap getMap(Thread t) {
  return t.threadLocals;
}

// 先获取线程的ThreadLocalMap,再往对应的map中设置值
public void set(T value) {
	Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

为此,我们可以通过维护一个静态变量,记录下当前线程所使用的需要跨线程共享的ThreadLocal表,然后再创建线程运行上下文复制线程变量,等线程运行时再其前后以需要的线程变量替换,运行完之后再还原.

// 用该结构包围实际运行的方法
public void run() {
	Map<MyThreadLocal<Object>, Object> replace = null;
    try {
        replace = replace();
        // 设置上下文
        runnable.run();
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        // 还原上下文
        restore(replace);
    }
}

测试代码

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * @author Lion Zhou
 * @date 2022/9/15
 */
public class Test {
    public static void test5() {
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        // 用一个空任务让线程池创建好线程
        executorService.submit(() -> {
        });

        // 使用我们定义好的线程变量
        MyThreadLocal<Integer> mtl1 = new MyThreadLocal<>();
        mtl1.set(333);

        executorService.submit(MyThreadLocalContext.go(() -> {
            System.out.print("1:");
            System.out.println(mtl1.get());
            // 修改线程变量,因为是副本,不影响其他线程中的值
            mtl1.set(111);
        }));

        executorService.submit(() -> {
            System.out.print("2:");
            // 正常使用为 null
            System.out.println(mtl1.get());
        });

        executorService.submit(MyThreadLocalContext.go(() -> {
            System.out.print("3:");
            // 还是 333
            System.out.println(mtl1.get());
        }));

        executorService.shutdown();
        System.out.println("end:" + mtl1.get());
    }

    public static void main(String[] args) {
        test5();
    }
}

源码

import java.util.WeakHashMap;

/**
 * @author Lion Zhou
 * @date 2022/9/15
 */
public class MyThreadLocal<T> extends InheritableThreadLocal<T> {

    // 维护每个线程所持有的 MyThreadLocal 为后续跨线程传递使用
    static InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = new InheritableThreadLocal<>();

    @Override
    public T get() {
        // 直接调用原本的 get 方法
        T t = super.get();
        if (null == t && null != holder.get()) {
            // 对应key的值已经不存在了,删除当前的持有数据
            holder.get().remove(this);
        }
        return t;
    }

    @Override
    public void set(T value) {
        super.set(value);
        if (holder.get() == null) {
            holder.set(new WeakHashMap<>(8));
        }
        holder.get().put((MyThreadLocal<Object>) this, null);
    }

    @Override
    public void remove() {
        super.remove();
        if (holder.get() != null) {
            holder.get().remove(this);
        }
    }
}
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.WeakHashMap;

/**
 * @author Lion Zhou
 * @date 2022/9/15
 */
public class MyThreadLocalContext {

    public static Runnable go(Runnable runnable) {
        InheritableThreadLocal<WeakHashMap<MyThreadLocal<Object>, Object>> holder = MyThreadLocal.holder;

        Map<MyThreadLocal<Object>, Object> map = Collections.emptyMap();
        if (null != holder.get()) {
            map = new WeakHashMap<>(holder.get().size());
//            System.out.println("start");
            for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.get().entrySet()) {
//                System.out.println(entry.getKey().get());
                map.put(entry.getKey(), entry.getKey().get());
            }
//            System.out.println("end");
        }
        return new Context(map, runnable);
    }

    public static class Context implements Runnable {
        Map<MyThreadLocal<Object>, Object> holder;
        Runnable runnable;

        public Context(Map<MyThreadLocal<Object>, Object> holder, Runnable runnable) {
            this.holder = holder;
            this.runnable = runnable;
        }

        public Map<MyThreadLocal<Object>, Object> replace() {
            // 保留原本的线程本地变量
            Map<MyThreadLocal<Object>, Object> replace = new WeakHashMap<>();

            // 将复制过来的值重新赋值给当前上下文环境
//            System.out.println("context start");
            // 上下文切换
            for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
//                System.out.println(String.format("old: %s, new: %s", Optional.ofNullable(entry.getKey().get()).orElse("null").toString(),
//                        entry.getValue()));

                // 保存 线程本地变量 的现场
                replace.put(entry.getKey(), entry.getKey().get());
                // 替换需要的上下文
                entry.getKey().set(entry.getValue());
            }
//            System.out.println("context end");
            return replace;
        }

        public void restore(Map<MyThreadLocal<Object>, Object> restore) {
            if (null == restore) {
                return;
            }
            for (Map.Entry<MyThreadLocal<Object>, Object> entry : holder.entrySet()) {
                // 原本的值
                Object old = restore.get(entry.getKey());
                if (null == old) {
                    // 原本就为null
                    entry.getKey().remove();
                } else {
                    entry.getKey().set(old);
                }
            }
        }

        @Override
        public void run() {
            Map<MyThreadLocal<Object>, Object> replace = null;
            try {
                replace = replace();
                // 设置上下文
                runnable.run();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // 还原上下文
                restore(replace);
            }
        }
    }

}

总结

源码名字不好听,见谅.
代码很简单,只是为了演示,实际还存在一些问题,比如在替换上下文时没有使用 deepcopy等.
WeakHashMap使用就是基本问题了,因为线程变量是跨线程的,并非线程独有值,因此不能破坏原本变量的生命周期(由此导致内存泄露),所以要用弱引用.

相关资料

  • 阿里开源的线程间上下文传递解决方案 支持编程和java agent的形式

你可能感兴趣的:(Java,java)