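I previously wrote a blog post, "trace_id-based tracing (covering Feign, Hystrix, thread pools and other scenarios)". It showed how to implement tracing for a distributed system built on a microservice architecture, and covered scenarios such as Feign, Hystrix and thread pools.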
One fairly common scenario was still missing, however: ForkJoinPool, a thread pool implementation widely used in Java.
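In particular, Java 8 parallel streams use ForkJoinPool (specifically ForkJoinPool.commonPool()) as their default implementation, so when business logic runs on a parallel stream, the trace_id in the logs is easily lost at that point.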
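To make the break concrete, here is a minimal sketch. It uses a plain `ThreadLocal` as a stand-in for SLF4J's `MDC` (Logback's MDC adapter, for example, also keeps its map in thread-local storage), so it runs without a logging dependency; `TraceIdLostDemo` and `traceIdSeenByWorker` are hypothetical names. A trace_id set on the calling thread is not visible to the ForkJoinPool worker that runs the task:

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;

public class TraceIdLostDemo {
    // Hypothetical stand-in for MDC: per-thread storage for the trace_id
    static final ThreadLocal<String> TRACE_ID = new ThreadLocal<>();

    /** Sets a trace_id on the calling thread, then reads it back from inside a pool task. */
    static String traceIdSeenByWorker(ForkJoinPool pool) throws InterruptedException, ExecutionException {
        TRACE_ID.set("trace-123");
        try {
            Callable<String> readTraceId = TRACE_ID::get;
            // The task runs on a pool worker thread, whose ThreadLocal slot was never set
            return pool.submit(readTraceId).get();
        } finally {
            TRACE_ID.remove();
        }
    }

    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            System.out.println("worker sees trace_id = " + traceIdSeenByWorker(pool));
        } finally {
            pool.shutdown();
        }
    }
}
```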
This article explores how to implement trace_id-based tracing for ForkJoinPool, to improve the traceability of the system.
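ForkJoinPool is a thread pool implementation in the JDK that is particularly suited to tasks that decompose recursively. It uses a work-stealing algorithm: a task is split into smaller subtasks that are pushed onto the current worker's deque, and idle worker threads steal subtasks from the deques of busy workers, so the subtasks execute in parallel.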
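For readers less familiar with the fork/join model, a minimal sketch of a recursively decomposed task. The `SumTask` class and its threshold of 1,000 elements are illustrative assumptions, not part of the original code: the task splits an array range in half until it is small enough to sum directly, `fork()` pushes one half onto the current worker's deque (where idle workers can steal it), and `join()` waits for that half.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ForkJoinSumDemo {
    /** Sums array[from, to) by recursively splitting the range in half. */
    static final class SumTask extends RecursiveTask<Long> {
        private static final long serialVersionUID = 1L;
        private static final int THRESHOLD = 1_000; // below this size, sum sequentially
        private final long[] array;
        private final int from;
        private final int to;

        SumTask(long[] array, int from, int to) {
            this.array = array;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= THRESHOLD) {
                long sum = 0;
                for (int i = from; i < to; i++) {
                    sum += array[i];
                }
                return sum;
            }
            int mid = (from + to) >>> 1;
            SumTask left = new SumTask(array, from, mid);
            SumTask right = new SumTask(array, mid, to);
            left.fork();                     // push the left half onto this worker's deque; idle workers may steal it
            long rightSum = right.compute(); // work on the right half in the current thread
            return rightSum + left.join();   // wait for (or help complete) the left half
        }
    }

    static long parallelSum(long[] array, ForkJoinPool pool) {
        return pool.invoke(new SumTask(array, 0, array.length));
    }

    public static void main(String[] args) {
        long[] numbers = new long[1_000_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        ForkJoinPool pool = new ForkJoinPool();
        try {
            System.out.println(parallelSum(numbers, pool)); // 500000500000
        } finally {
            pool.shutdown();
        }
    }
}
```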
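To implement trace_id-based tracing, the design is as follows: when a task is submitted, take a copy of the submitting thread's MDC context (which holds the trace_id); when the task runs on a worker thread, install that copy, and restore the worker's previous context once the task finishes. The implementation consists of the classes below.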
package com.github.jesse.l2cache.util.pool;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
/**
* Custom {@link ForkJoinPool} that carries the MDC context over to worker threads, for tracing
*
* @author chenck
* @date 2021/5/11 14:48
*/
public class MdcForkJoinPool extends ForkJoinPool {
/**
* max #workers - 1
*/
public static final int MAX_CAP = 0x7fff;
/**
* the default parallelism level
*/
public static final int DEFAULT_PARALLELISM = Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors());
/**
* the default thread name prefix
*/
public static final String DEFAULT_THREAD_NAME_PREFIX = "MdcForkJoinPool";
/**
* Sequence number for creating workerNamePrefix.
*/
private static int poolNumberSequence;
/**
* Returns the next sequence number. We don't expect this to
* ever contend, so use simple builtin sync.
*/
private static final synchronized int nextPoolId() {
return ++poolNumberSequence;
}
/**
* Common (static) pool.
*/
static final MdcForkJoinPool mdcCommon = new MdcForkJoinPool();
public static MdcForkJoinPool mdcCommonPool() {
return mdcCommon;
}
// constructor
public MdcForkJoinPool() {
this(DEFAULT_PARALLELISM, DEFAULT_THREAD_NAME_PREFIX);
}
public MdcForkJoinPool(int parallelism) {
this(parallelism, DEFAULT_THREAD_NAME_PREFIX);
}
public MdcForkJoinPool(String threadNamePrefix) {
this(DEFAULT_PARALLELISM, threadNamePrefix);
}
public MdcForkJoinPool(int parallelism, String threadNamePrefix) {
this(parallelism, new LimitedThreadForkJoinWorkerThreadFactory(parallelism, threadNamePrefix + "-" + nextPoolId()), null, false);
}
/**
* Creates a new MdcForkJoinPool.
*
* @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
* @param factory the factory for creating new threads. For default value, use
* {@link #defaultForkJoinWorkerThreadFactory}.
* @param handler the handler for internal worker threads that terminate due to unrecoverable errors encountered
* while executing tasks. For default value, use {@code null}.
* @param asyncMode if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
* joined. This mode may be more appropriate than default locally stack-based mode in applications
* in which worker threads only process event-style asynchronous tasks. For default value, use
* {@code false}.
*/
public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, Thread.UncaughtExceptionHandler handler, boolean asyncMode) {
super(parallelism, factory, handler, asyncMode);
}
// Execution methods
@Override
public <T> T invoke(ForkJoinTask<T> task) {
if (task == null) {
throw new NullPointerException();
}
return super.invoke(new ForkJoinTaskMdcWrapper<T>(task));
}
@Override
public void execute(ForkJoinTask<?> task) {
if (task == null) {
throw new NullPointerException();
}
super.execute(new ForkJoinTaskMdcWrapper<>(task));
}
// AbstractExecutorService methods
@Override
public void execute(Runnable task) {
if (task == null) {
throw new NullPointerException();
}
super.execute(new RunnableMdcWarpper(task));
}
@Override
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
if (task == null) {
throw new NullPointerException();
}
return super.submit(new ForkJoinTaskMdcWrapper<T>(task));
}
@Override
public <T> ForkJoinTask<T> submit(Callable<T> task) {
if (task == null) {
throw new NullPointerException();
}
return super.submit(new CallableMdcWrapper<>(task));
}
@Override
public <T> ForkJoinTask<T> submit(Runnable task, T result) {
if (task == null) {
throw new NullPointerException();
}
return super.submit(new RunnableMdcWarpper(task), result);
}
@Override
public ForkJoinTask<?> submit(Runnable task) {
if (task == null) {
throw new NullPointerException();
}
return super.submit(new RunnableMdcWarpper(task));
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) {
if (tasks == null) {
throw new NullPointerException();
}
Collection<Callable<T>> wrapperTasks = new ArrayList<>();
for (Callable<T> task : tasks) {
wrapperTasks.add(new CallableMdcWrapper<>(task));
}
return super.invokeAll(wrapperTasks);
}
}
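A usage sketch for running a parallel stream on a dedicated pool such as `MdcForkJoinPool` rather than the common pool. Parallel streams take no executor argument, so the usual workaround is to start the terminal operation from inside a task submitted to the target pool; in current JDKs the stream then forks its subtasks into the pool of the calling worker. This is observed JDK behavior rather than a documented guarantee. To stay runnable without SLF4J, the sketch uses a plain `ForkJoinPool` with a hypothetical thread-naming helper (`namedPool`) and records which threads processed the elements:

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.stream.IntStream;

public class ParallelStreamInPoolDemo {
    /** Hypothetical helper: a pool whose workers are named "<prefix>-<index>" so they are easy to spot. */
    static ForkJoinPool namedPool(int parallelism, String prefix) {
        ForkJoinPool.ForkJoinWorkerThreadFactory factory = pool -> {
            ForkJoinWorkerThread thread = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
            thread.setName(prefix + "-" + thread.getPoolIndex());
            return thread;
        };
        return new ForkJoinPool(parallelism, factory, null, false);
    }

    /** Starts a parallel stream from inside a task of the given pool and records which threads ran it. */
    static Set<String> threadsUsedByParallelStream(ForkJoinPool pool) throws Exception {
        Set<String> threadNames = ConcurrentHashMap.newKeySet();
        pool.submit(() -> IntStream.range(0, 10_000)
                        .parallel()
                        .forEach(i -> threadNames.add(Thread.currentThread().getName())))
                .get(); // wait for the whole stream to finish
        return threadNames;
    }

    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = namedPool(4, "demo-pool");
        try {
            System.out.println(threadsUsedByParallelStream(pool));
        } finally {
            pool.shutdown();
        }
    }
}
```

With `MdcForkJoinPool`, you would submit the same lambda to `MdcForkJoinPool.mdcCommonPool()` instead, so that the outer task is wrapped with the caller's MDC context. One caveat, as far as I can tell from the JDK sources: subtasks forked by the stream are pushed directly onto the worker's deque and never pass through the overridden `submit`/`execute` methods, so a worker that steals one of them runs it with whatever MDC context that worker currently has, not necessarily the submitter's.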
package com.github.jesse.l2cache.util.pool;
import org.slf4j.MDC;
import java.util.Map;
import java.util.concurrent.Callable;
/**
* @author chenck
* @date 2021/5/11 17:09
*/
public class CallableMdcWrapper<T> implements Callable<T> {
private static final long serialVersionUID = 1L;
Callable<T> callable;
Map<String, String> contextMap;
public CallableMdcWrapper(Callable<T> callable) {
this.callable = callable;
this.contextMap = MDC.getCopyOfContextMap();
}
@Override
public T call() throws Exception {
Map<String, String> oldContext = MdcUtil.beforeExecution(contextMap);
try {
return callable.call();
} finally {
MdcUtil.afterExecution(oldContext);
}
}
}
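A self-contained sketch of the same pattern, again with a hypothetical `ThreadLocal` map standing in for `MDC` so it runs without SLF4J (`ContextCallable` and `traceIdInWorker` are illustrative names). The detail worth noting is when the snapshot is taken: `CallableMdcWrapper` copies the context in its constructor, so the wrapper must be created on the submitting thread (as `MdcForkJoinPool.submit` does), not inside the worker. The same wrapper works with any `ExecutorService`:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ContextCallableDemo {
    // Hypothetical stand-in for MDC's per-thread context map
    static final ThreadLocal<Map<String, String>> CONTEXT = new ThreadLocal<>();

    /** Same shape as CallableMdcWrapper: snapshot on construction, install in call(), restore afterwards. */
    static final class ContextCallable<T> implements Callable<T> {
        private final Callable<T> delegate;
        private final Map<String, String> captured;

        ContextCallable(Callable<T> delegate) {
            this.delegate = delegate;
            Map<String, String> current = CONTEXT.get();
            // Copy it, as MDC.getCopyOfContextMap() does, so later changes on the caller do not leak in
            this.captured = current == null ? null : new HashMap<>(current);
        }

        @Override
        public T call() throws Exception {
            Map<String, String> previous = CONTEXT.get();
            CONTEXT.set(captured);
            try {
                return delegate.call();
            } finally {
                CONTEXT.set(previous); // put the worker's own context back
            }
        }
    }

    /** Submits a task that reads trace_id; the wrapper is built here, on the submitting thread. */
    static String traceIdInWorker(ExecutorService executor) throws Exception {
        Map<String, String> context = new HashMap<>();
        context.put("trace_id", "trace-456");
        CONTEXT.set(context);
        try {
            Callable<String> task = new ContextCallable<String>(() -> {
                Map<String, String> seen = CONTEXT.get();
                return seen == null ? null : seen.get("trace_id");
            });
            return executor.submit(task).get();
        } finally {
            CONTEXT.remove();
        }
    }

    public static void main(String[] args) throws Exception {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            System.out.println(traceIdInWorker(executor)); // trace-456
        } finally {
            executor.shutdown();
        }
    }
}
```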
package com.github.jesse.l2cache.util.pool;
import org.slf4j.MDC;
import java.util.Map;
/**
* Wraps a Runnable with the MDC context captured when the wrapper is created
*
* @author chenck
* @date 2020/9/23 19:37
*/
public class RunnableMdcWarpper implements Runnable {
private static final long serialVersionUID = 1L;
Runnable runnable;
Map<String, String> contextMap;
Object param;
public RunnableMdcWarpper(Runnable runnable) {
this.runnable = runnable;
this.contextMap = MDC.getCopyOfContextMap();
}
public RunnableMdcWarpper(Runnable runnable, Object param) {
this.runnable = runnable;
this.contextMap = MDC.getCopyOfContextMap();
this.param = param;
}
@Override
public void run() {
Map<String, String> oldContext = MdcUtil.beforeExecution(contextMap);
try {
runnable.run();
} finally {
MdcUtil.afterExecution(oldContext);
}
}
public Object getParam() {
return param;
}
}
package com.github.jesse.l2cache.util.pool;
import org.slf4j.MDC;
import java.util.Map;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
/**
* @author chenck
* @date 2021/5/11 16:56
* @see <a href="https://stackoverflow.com/questions/36026402/how-to-use-mdc-with-forkjoinpool">How to use MDC with ForkJoinPool (Stack Overflow)</a>
*/
public class ForkJoinTaskMdcWrapper<T> extends ForkJoinTask<T> {
private static final long serialVersionUID = 1L;
/**
* If non-null, overrides the value returned by the underlying task.
*/
private final AtomicReference<T> override = new AtomicReference<>();
private ForkJoinTask<T> task;
private Map<String, String> newContext;
public ForkJoinTaskMdcWrapper(ForkJoinTask<T> task) {
this.task = task;
this.newContext = MDC.getCopyOfContextMap();
}
@Override
public T getRawResult() {
T result = override.get();
if (result != null) {
return result;
}
return task.getRawResult();
}
@Override
protected void setRawResult(T value) {
override.set(value);
}
@Override
protected boolean exec() {
Map<String, String> oldContext = MdcUtil.beforeExecution(newContext);
try {
task.invoke();
return true;
} finally {
MdcUtil.afterExecution(oldContext);
}
}
}
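A runnable sketch of the same wrapper idea with a `ThreadLocal` stand-in for `MDC`; `ContextTask` and `ReadTraceId` are hypothetical names. It keeps the result in a plain field instead of the `AtomicReference` used above, which is a simplification: `exec()` runs the wrapped task with `invoke()` on the current worker, and returning `true` marks the wrapper as completed normally, so `pool.invoke(...)` returns the inner task's result.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ContextForkJoinTaskDemo {
    // Hypothetical stand-in for MDC: a single trace_id per thread
    static final ThreadLocal<String> TRACE_ID = new ThreadLocal<>();

    /** Same idea as ForkJoinTaskMdcWrapper: capture on construction, install in exec(), restore afterwards. */
    static final class ContextTask<T> extends ForkJoinTask<T> {
        private static final long serialVersionUID = 1L;
        private final ForkJoinTask<T> task;
        private final String captured;
        private T result;

        ContextTask(ForkJoinTask<T> task) {
            this.task = task;
            this.captured = TRACE_ID.get();
        }

        @Override
        public T getRawResult() {
            return result;
        }

        @Override
        protected void setRawResult(T value) {
            result = value;
        }

        @Override
        protected boolean exec() {
            String previous = TRACE_ID.get();
            TRACE_ID.set(captured);
            try {
                result = task.invoke(); // run the wrapped task on this thread and keep its result
                return true;            // true = this task completed normally
            } finally {
                TRACE_ID.set(previous);
            }
        }
    }

    /** A task that reports the trace_id visible on the thread that runs it. */
    static final class ReadTraceId extends RecursiveTask<String> {
        private static final long serialVersionUID = 1L;

        @Override
        protected String compute() {
            return TRACE_ID.get();
        }
    }

    static String traceIdSeenByTask(ForkJoinPool pool) {
        TRACE_ID.set("trace-789");
        try {
            return pool.invoke(new ContextTask<>(new ReadTraceId()));
        } finally {
            TRACE_ID.remove();
        }
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            System.out.println(traceIdSeenByTask(pool)); // trace-789
        } finally {
            pool.shutdown();
        }
    }
}
```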
package com.github.jesse.l2cache.util.pool;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
/**
* Custom ForkJoinWorkerThread, used together with {@link LimitedThreadForkJoinWorkerThreadFactory} to cap the number of threads a ForkJoinPool creates
*
* @author chenck
* @date 2023/5/6 13:49
*/
public class LimitedThreadForkJoinWorkerThread extends ForkJoinWorkerThread {
protected LimitedThreadForkJoinWorkerThread(ForkJoinPool pool) {
super(pool);
setPriority(Thread.NORM_PRIORITY); // use normal thread priority
setDaemon(false); // non-daemon thread: the JVM does not exit while this worker is alive
}
protected LimitedThreadForkJoinWorkerThread(ForkJoinPool pool, String threadName) {
super(pool);
setPriority(Thread.NORM_PRIORITY); // use normal thread priority
setDaemon(false); // non-daemon thread: the JVM does not exit while this worker is alive
setName(threadName);
}
}
package com.github.jesse.l2cache.util.pool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Custom ForkJoinWorkerThreadFactory that caps the number of threads created in a ForkJoinPool; beyond the cap, the pool keeps running tasks on its existing threads
*
* @author chenck
* @date 2023/5/6 13:48
*/
public class LimitedThreadForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
protected static Logger logger = LoggerFactory.getLogger(LimitedThreadForkJoinWorkerThreadFactory.class);
/**
* Maximum number of threads
*/
private final int maxThreads;
/**
* Thread name prefix
*/
private String threadNamePrefix;
/**
* Current number of threads
*/
private final AtomicInteger threadCount = new AtomicInteger(0);
public LimitedThreadForkJoinWorkerThreadFactory(int maxThreads) {
this.maxThreads = maxThreads;
}
public LimitedThreadForkJoinWorkerThreadFactory(int maxThreads, String threadNamePrefix) {
this.maxThreads = maxThreads;
this.threadNamePrefix = threadNamePrefix;
}
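/**
* Caps the number of threads: once maxThreads workers are alive, no new thread is created and the pool keeps running tasks on its existing threads
*/
@Override
public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
int count = threadCount.incrementAndGet();
// Within the limit: create a new thread (threadCount has already been incremented)
if (count <= maxThreads) {
boolean useDefaultName = null == threadNamePrefix || "".equals(threadNamePrefix.trim());
// Use the custom thread name when a prefix is configured
return new CountedWorkerThread(pool, useDefaultName ? null : threadNamePrefix + "-worker-" + count);
}
// Over the limit: do not create a new thread, and undo the increment
threadCount.decrementAndGet();
if (logger.isDebugEnabled()) {
logger.debug("Exceeded maximum number of threads");
}
return null; // null tells ForkJoinPool that no thread was created
}
/**
* Worker thread that releases its slot in threadCount when it terminates. ForkJoinPool retires idle workers,
* so without this decrement the counter would only ever grow; once it reached maxThreads, the factory would
* refuse to replace retired workers and the pool could be left without any thread to run new tasks.
*/
private final class CountedWorkerThread extends LimitedThreadForkJoinWorkerThread {
CountedWorkerThread(ForkJoinPool pool, String threadName) {
super(pool);
if (threadName != null) {
setName(threadName);
}
}
@Override
protected void onTermination(Throwable exception) {
try {
super.onTermination(exception);
} finally {
threadCount.decrementAndGet();
}
}
}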
}
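A minimal, self-contained sketch of the capping behavior, independent of the classes above (`CappedFactory` and `workersUsed` are hypothetical names). The pool is created with parallelism 4, but the factory hands out at most 2 live workers and returns `null` beyond that, which the `ForkJoinWorkerThreadFactory` contract allows as "request rejected". Overriding `onTermination` frees a slot when the pool retires a worker, so the cap bounds live threads rather than threads ever created:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

public class CappedFactoryDemo {
    /** Simplified, hypothetical version of the factory above: at most maxThreads live workers. */
    static final class CappedFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
        private final int maxThreads;
        private final AtomicInteger live = new AtomicInteger();

        CappedFactory(int maxThreads) {
            this.maxThreads = maxThreads;
        }

        @Override
        public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            if (live.incrementAndGet() > maxThreads) {
                live.decrementAndGet();
                return null; // request rejected: no new thread
            }
            return new ForkJoinWorkerThread(pool) {
                @Override
                protected void onTermination(Throwable exception) {
                    try {
                        super.onTermination(exception);
                    } finally {
                        live.decrementAndGet(); // release the slot when the pool retires this worker
                    }
                }
            };
        }
    }

    /** Runs taskCount short tasks and returns the distinct worker threads that executed them. */
    static Set<Thread> workersUsed(int parallelism, int maxThreads, int taskCount, AtomicInteger completed) {
        ForkJoinPool pool = new ForkJoinPool(parallelism, new CappedFactory(maxThreads), null, false);
        Set<Thread> workers = ConcurrentHashMap.newKeySet();
        List<Callable<Void>> tasks = new ArrayList<>();
        for (int i = 0; i < taskCount; i++) {
            tasks.add(() -> {
                Thread current = Thread.currentThread();
                // Some JDK versions let the calling thread help run tasks; count pool workers only
                if (current instanceof ForkJoinWorkerThread) {
                    workers.add(current);
                }
                Thread.sleep(5);
                completed.incrementAndGet();
                return null;
            });
        }
        try {
            pool.invokeAll(tasks);
        } finally {
            pool.shutdown();
        }
        return workers;
    }

    public static void main(String[] args) {
        AtomicInteger completed = new AtomicInteger();
        Set<Thread> workers = workersUsed(4, 2, 50, completed);
        System.out.println(completed.get() + " tasks on " + workers.size() + " worker(s)");
    }
}
```

All tasks still complete; they just share the capped set of workers.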
package com.github.jesse.l2cache.util.pool;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
/**
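* Java 8 parallel streams use the common ForkJoinPool by default; if all of the common pool's threads are busy when tasks are submitted, the tasks are delayed.
*
* CPU-bound: if the tasks placed in the ForkJoinPool are short and enough CPU capacity is available, the delay above does not occur (this covers most ForkJoinPool use cases).
* I/O-bound: if the tasks are long-running I/O tasks that are not limited by the CPU, tasks are delayed and the pool becomes a bottleneck.
* Summary: ForkJoinPool is best suited to CPU-bound tasks. When tasks perform I/O, synchronize between threads, call sleep(), or otherwise block a thread for a long time, combine them with a ManagedBlocker.
*
* A ManagedBlocker for I/O-blocking tasks tells the ForkJoinPool that the current task is about to block, so the pool can create a new `spare thread` to run newly submitted tasks.
*
* [Problem] The ForkJoinPool implementation only caps the number of running threads at 32767, so when blocking is managed through a ManagedBlocker and the number of newly created threads is not limited, the pool may create enough threads to cause an OOM. How can we cap the number of spare threads a ForkJoinPool creates?
* [Analysis]
* 1. The common pool's internal commonMaxSpares limit (default 256; configurable in JDK 9+ via the java.util.concurrent.ForkJoinPool.common.maximumSpares system property) caps the number of `spare threads` created in tryCompensate.
* 2. That limit applies only to the common pool, and tryCompensate does not necessarily hit it; when it is not hit, spare threads may be created without bound to avoid blocking, which can still end in an OOM.
* 3. The ForkJoinPool implementation restricts the maximum number of running threads to 32767. Attempts to create a pool larger than that throw IllegalArgumentException, and the implementation rejects submitted tasks (by throwing RejectedExecutionException) only when the pool is shut down or its internal resources are exhausted.
* [Solution]
* When managing blocking, use the custom {@link LimitedThreadForkJoinWorkerThreadFactory} to cap the number of threads the ForkJoinPool can create and let the pool fall back to its existing threads, which avoids creating `spare threads` without bound.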
*
* @author chenck
* @date 2023/5/5 18:30
*/
public class MyManagedBlocker implements ForkJoinPool.ManagedBlocker {
private Function function;
private Object key;
private Object result;
private boolean done = false;
public MyManagedBlocker(Object key, Function function) {
this.key = key;
this.function = function;
}
@Override
public boolean block() throws InterruptedException {
result = function.apply(key);
done = true;
return true; // the work is done, so no further blocking is necessary (ManagedBlocker contract)
}
@Override
public boolean isReleasable() {
return done;
}
public Object getResult() {
return result;
}
}
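For context, a blocker like `MyManagedBlocker` is meant to be passed to `ForkJoinPool.managedBlock(...)`, which, when called from a worker thread, may arrange for a spare thread to be activated before invoking `block()`. A generic sketch follows; `LookupBlocker`, `blockingLookup`, `slowLookup` and `lookupInPool` are hypothetical names, and the 20 ms sleep stands in for real I/O:

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;

public class ManagedBlockerDemo {
    /** Hypothetical generic variant of MyManagedBlocker: performs one blocking lookup inside block(). */
    static final class LookupBlocker<K, V> implements ForkJoinPool.ManagedBlocker {
        private final K key;
        private final Function<K, V> lookup;
        private V result;
        private boolean done;

        LookupBlocker(K key, Function<K, V> lookup) {
            this.key = key;
            this.lookup = lookup;
        }

        @Override
        public boolean block() {
            result = lookup.apply(key); // the potentially blocking call, e.g. a remote cache read
            done = true;
            return true;                // no further blocking is needed
        }

        @Override
        public boolean isReleasable() {
            return done;
        }

        V getResult() {
            return result;
        }
    }

    /** Runs the lookup through ForkJoinPool.managedBlock so the pool may compensate while this thread blocks. */
    static <K, V> V blockingLookup(K key, Function<K, V> lookup) throws InterruptedException {
        LookupBlocker<K, V> blocker = new LookupBlocker<>(key, lookup);
        ForkJoinPool.managedBlock(blocker);
        return blocker.getResult();
    }

    /** Simulated slow I/O call. */
    static String slowLookup(String key) {
        try {
            Thread.sleep(20);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt status for the caller
        }
        return "value-of-" + key;
    }

    /** Submits the blocking lookup to the pool and waits for the result. */
    static String lookupInPool(ForkJoinPool pool, String key) throws Exception {
        Callable<String> task = () -> blockingLookup(key, ManagedBlockerDemo::slowLookup);
        return pool.submit(task).get();
    }

    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            System.out.println(lookupInPool(pool, "user:1")); // value-of-user:1
        } finally {
            pool.shutdown();
        }
    }
}
```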
package com.github.jesse.l2cache.util.pool;
import org.slf4j.MDC;
import java.util.Map;
/**
* @author chenck
* @date 2021/5/11 17:00
*/
public class MdcUtil {
/**
* Invoked before running a task.
*
* @param newMdcContext the new MDC context
* @return the old MDC context
*/
public static Map<String, String> beforeExecution(Map<String, String> newMdcContext) {
Map<String, String> oldMdcContext = MDC.getCopyOfContextMap();
if (newMdcContext == null) {
MDC.clear();
} else {
MDC.setContextMap(newMdcContext);
}
return oldMdcContext;
}
/**
* Invoked after running a task.
*
* @param oldMdcContext the old MDC context
*/
public static void afterExecution(Map<String, String> oldMdcContext) {
if (oldMdcContext == null) {
MDC.clear();
} else {
MDC.setContextMap(oldMdcContext);
}
}
}
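As for why `MdcUtil` returns and restores the previous map instead of just calling `MDC.clear()` afterwards: a task can run on a thread that already carries its own context, for example a ForkJoinPool worker that runs another wrapped task while it is helping to complete a join. A small sketch with a hypothetical `ThreadLocal` stand-in for `MDC` (`ContextRestoreDemo` and `runOnBusyThread` are illustrative names) shows the thread's own trace_id coming back after the task finishes:

```java
import java.util.HashMap;
import java.util.Map;

public class ContextRestoreDemo {
    // Hypothetical stand-in for MDC's per-thread context map
    static final ThreadLocal<Map<String, String>> CONTEXT = new ThreadLocal<>();

    /** Mirrors MdcUtil.beforeExecution: install the task's context, return the thread's previous one. */
    static Map<String, String> beforeExecution(Map<String, String> newContext) {
        Map<String, String> oldContext = CONTEXT.get();
        CONTEXT.set(newContext == null ? null : new HashMap<>(newContext));
        return oldContext;
    }

    /** Mirrors MdcUtil.afterExecution: put the thread's previous context back. */
    static void afterExecution(Map<String, String> oldContext) {
        CONTEXT.set(oldContext);
    }

    static String traceId() {
        Map<String, String> context = CONTEXT.get();
        return context == null ? null : context.get("trace_id");
    }

    /** Runs a "task" inline on a thread that already has its own trace_id. Returns {inside, after}. */
    static String[] runOnBusyThread() {
        Map<String, String> own = new HashMap<>();
        own.put("trace_id", "outer");
        CONTEXT.set(own);
        try {
            Map<String, String> taskContext = new HashMap<>();
            taskContext.put("trace_id", "inner");

            String inside;
            Map<String, String> oldContext = beforeExecution(taskContext);
            try {
                inside = traceId(); // the task sees its own trace_id
            } finally {
                afterExecution(oldContext);
            }
            return new String[] {inside, traceId()}; // the thread's own trace_id is back
        } finally {
            CONTEXT.remove();
        }
    }

    public static void main(String[] args) {
        String[] seen = runOnBusyThread();
        System.out.println(seen[0] + " / " + seen[1]); // inner / outer
    }
}
```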
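trace_id-based tracing is one of the key techniques for improving the traceability of distributed systems.
By passing the trace_id along with each task and recording it, and combining this with logging and monitoring systems, developers can follow how a request flows through the system and how the system performs, and so locate and fix problems quickly.
In practice, choose the tracing strategy and tools according to the specific business scenario and performance requirements, to strike the right balance between performance and traceability.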
References: