在实际应用中,可以通过扩展实现对线程池运行状态的跟踪,
了解线程池的具体使用情况以及每个线程实行耗时信息,输出一些调试的信息,
以帮助系统故障诊断,这对于多线程程序错误排查是很有帮助的。
在JDK默认的ThreadPoolExecutor实现中,提供了空的beforeExecute,afterExecute,terminated实现。三个方法均没有具体实现并定义为protected,就是为了在子类中进行扩展改写,这为记录线程执行耗时情况提供了可能。
ThreadPoolExecutor中三个方法如下:
protected void beforeExecute(Thread t, Runnable r) {
}
protected void afterExecute(Runnable r, Throwable t) {
}
protected void terminated() {
}
ThreadPoolExecutor.Worker是ThreadPoolExecutor的内部类,它是一个实现了Runnable接口的类,ThreadPoolExecutor线程池中的工作线程也正是Worker实例,Worker 的run方法会被线程池以多线程模式异步调用,Worker的run方法直接调用了runWorker,即runWorker会同时被多个线程访问,因此beforeExecute,afterExecute接口也将同时被多线程访问。
Worker内部类的定义以及run方法直接调用runWorker
//Worker类定义如下
private final class Worker
extends AbstractQueuedSynchronizer
implements Runnable
{
/** run()方法调用 */
public void run() {
runWorker(this);
}
//其他属性方法略去
}
ThreadPoolExecutor的实例方法runWorker方法具体执行任务run方法,任务运行前调用了beforeExecute,任务运行后调用afterExecute,runWorker方法具体实现如下:
final void runWorker(Worker w) {
Thread wt = Thread.currentThread();
Runnable task = w.firstTask;
w.firstTask = null;
w.unlock();
boolean completedAbruptly = true;
try {
while (task != null || (task = getTask()) != null) {
w.lock();
if ((runStateAtLeast(ctl.get(), STOP) ||
(Thread.interrupted() &&
runStateAtLeast(ctl.get(), STOP))) &&
!wt.isInterrupted())
wt.interrupt();
try {
//beforeExecute()在任务执行前被调用
beforeExecute(wt, task);
Throwable thrown = null;
try {
task.run();//运行任务
} catch (RuntimeException x) {
thrown = x; throw x;
} catch (Error x) {
thrown = x; throw x;
} catch (Throwable x) {
thrown = x; throw new Error(x);
} finally {
//afterExecute()在任务执行后被调用
afterExecute(task, thrown);
}
} finally {
task = null;
w.completedTasks++;
w.unlock();
}
}
completedAbruptly = false;
} finally {
processWorkerExit(w, completedAbruptly);
}
}
基于以上我们自定义一个线程池继承ThreadPoolExecutor,然后重写beforeExecute,afterExecute,terminated方法实现,在beforeExecute调用时记录当前线程开始时间并把它保存到一个ThreadLocal变量中,然后在afterExecute调用时读取,并计算当前线程耗时,在afterExecute调用时一并记录维护原子变量线程总数与总耗时信息,在terminated调用时计算平均耗时。
自定义线程池代码如下:
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
@Slf4j
public class MyThreadPool extends ThreadPoolExecutor {
private final ThreadLocal<Long> startTime = new ThreadLocal<>();
//记录任务数
private final AtomicLong tasksNum = new AtomicLong();
//记录任务总耗时
private final AtomicLong totalTime = new AtomicLong();
public MyThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
}
@Override
protected void beforeExecute(Thread t, Runnable r) {
super.beforeExecute(t, r);
log.info(String.format("BeforeExecute: ThreadID : %s",Thread.currentThread().getId()));
startTime.set(System.currentTimeMillis());
}
@Override
protected void afterExecute(Runnable r, Throwable t) {
try {
long taskTime = System.currentTimeMillis() - startTime.get();
//任务计数
tasksNum.incrementAndGet();
//任务总计耗时
totalTime.addAndGet(taskTime);
//[Running, --线程运行状态
// pool size = 3, --线程池中工作线程数
// active threads = 3, --线程池中活跃线程数(线程运行中非阻塞状态)
// queued tasks = 4, --线程池待处理的任务数
// completed tasks = 23]--已完成的任务数量
log.info(String.format("线程池相关信息:%s",this.toString()));
log.info(String.format("AfterExecute: ThreadID : %s, taskRunTime=%d(ms)",Thread.currentThread().getId(),taskTime));
} finally {
super.afterExecute(r,t);
}
}
@Override
protected void terminated() {
try {
log.info(String.format("Terminated: avgTaskRuntime=%d(ms)",totalTime.get() / tasksNum.get()));
} finally {
super.terminated();
}
}
}
测试代码如下:
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;
import java.util.concurrent.*;
public class TestMyThreadPool {
public static void main(String[] args) {
MyThreadPool myThreadPool = new MyThreadPool(3, 3, 0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>());
List<MyTask> taskList = new ArrayList<>();
for(int i=0;i<30;i++){
taskList.add(new MyTask());
}
try {
List<Future<Long>> futures = myThreadPool.invokeAll(taskList);
StringJoiner joiner = new StringJoiner("-","[","]");
for (Future<Long> f:futures){
joiner.add(String.valueOf(f.get()));
}
System.out.println("任务执行时间:"+ joiner.toString());
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
} finally {
myThreadPool.shutdown();
}
}
}
//模拟任务 实现Callable接口
class MyTask implements Callable<Long> {
@Override
public Long call() throws Exception {
long time= Math.round(Math.random() * 1000);
Thread.sleep(time);
return time;
}
}
还可以打印线程池的其他相关信息,比如任务总数this.getTaskCount(),核心线程数this.getCorePoolSize()等等。这里直接调用ThreadPoolExecutor 的toString 方法。
toString()打印如下信息
[Running, pool size = 3, active threads = 1, queued tasks = 0, completed tasks = 29]
参数解释
Running 运行状态
pool size = 3 线程池中工作线程数
active threads = 1 线程池中活跃线程数(线程运行中非阻塞状态)
queued tasks = 0 线程池待处理的任务数
completed tasks = 29 已完成的任务数量