《Java动手撸源码》手写实现线程池

《Java动手撸源码》手写实现线程池

文章目录

  • 《Java动手撸源码》手写实现线程池
  • 前言
  • 一、线程池的原理
  • 二、简易版本
  • 三、完善版本
    • 1.类图
    • 2.重点代码分析
      • 2.1 ThreadPool接口分析
      • 2.2 RunableQueue接口分析
      • 2.3 拒绝策略
      • 2.4 BasicThreadPool(重点)
      • 2.5 ThreadPoolTest类代码测试
  • 总结


前言

线程池想必大家都用过,无论是C++还是Java等各种语言里面都有线程池。我们通过对Thread的学习得知,Thread是一个重量级的资源,创建、启动以及销毁都是比较耗费系统资源的,因此对线程的重复利用是非常好的程序设计习惯,价值系统可创建的线程数量是有限的,线程数量和系统性能是一种抛物线的关系,也就是当线程数量到达某个数值时,性能反倒会降低很多,因此对线程的管理,尤其是数量的控制更能直接决定程序的性能。
本作者维护了一个仓库,名称叫Thread,打算在这个仓库里面手写实现Java多线程的一些经典技术,欢迎大家的star,本博文的代码已经上传到了该仓库,在com.thread.threadpool包下。
链接: 仓库地址。欢迎大家的star,您的star是我继续下去的动力。


一、线程池的原理

所谓线程池,通俗的理解为一个池子,池子里面存放着创建好的线程。当有任务提交给线程池执行时,池子中的某个线程会主动的执行该任务。如果池子中的线程数量不够应付数量众多的任务时,则需要自动扩充新的线程到池子中,但是该数量是优先的,就好比池塘的水界限一样。当任务比较少的时候,池子中的线程能够自动回收,释放资源。为了能够异步地提交任务和缓存未被处理的任务,需要有一个任务队列。如下图所示。
《Java动手撸源码》手写实现线程池_第1张图片
通过上面的描述可知,一个完整的线程池应该具有一下几个要素:

  1. 任务队列:用户缓存提交的任务
  2. 线程数量管理功能:一个线程池必须能够很好的管理和控制线程的数量,可通过如下的三个参数来实现,
  • 创建线程池时初始的线程数量init
  • 线程池自动扩充时的最大线程数量max;
  • 线程空闲时需要释放一部分线程,但是也要维护一定数量的核心线程core
    三者的关系是init<= core <= max
  1. 任务拒绝策略,如果线程数量已经达到上限且任务队列已满,则需要有相应的拒绝策略来通知任务
  2. 线程工厂,主要用于个性化定制线程,比如将线程设置为守护线程以及设置线程名称等。
  3. QueueSize:任务队列主要存放提交的Runnable,但是为了防止内存溢出,需要有limit数量对其进行控制。
  4. KeepedAlive时间:改时间主要决定线程各个重要参数自动维护的时间间隔。

二、简易版本

其实这个版本除了不能自动维护线程的数量,其他功能都差不多实现了。而且也比较好理解。

package com.thread.threadpool;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

public class SimpleThreadPool {
     

    private static final int DEFAULT_MAX_THREAD_SIZE = 10;

    private static final LinkedList<Runnable> TASK_QUEUE = new LinkedList<Runnable>();

    private static final String THREAD_POOL_PREFIX = "SIMPLE_THREAD_POOL-";

    private static final int DEFAULT_MAX_TASK_SIZE = 2000;

    private final List<WorkerThread> THREAD_QUEUE = new ArrayList<WorkerThread>();

    private static final DiscardPolicy DEFAULT_DISCARD_POLICY = () -> {
     
        throw new DiscardException("Discard this Task...");
    };

    private int seq = 0;

    private int threadSize;

    private int taskSize;

    private DiscardPolicy discardPolicy;

    private ThreadGroup threadGroup = new ThreadGroup("simpleThreadGroup");

    private volatile boolean isDestory = false;

    public SimpleThreadPool(int threadSize, int taskSize, DiscardPolicy discardPolicy) {
     
        this.threadSize = threadSize;
        this.taskSize = taskSize;
        this.discardPolicy = discardPolicy;
        init();
    }

    public SimpleThreadPool() {
     
        this(DEFAULT_MAX_THREAD_SIZE, DEFAULT_MAX_TASK_SIZE, DEFAULT_DISCARD_POLICY);
    }

    private void init() {
     
        for (int i = 0; i < threadSize; i++) {
     
            WorkerThread WorkerThread = new WorkerThread(threadGroup, THREAD_POOL_PREFIX + seq++);
            WorkerThread.start();
            THREAD_QUEUE.add(WorkerThread);
        }
    }


    public void submit(Runnable runner) throws Exception {
     
        if (isDestory) {
     
            throw new RuntimeException("The thread pool is already destoryed and not allow to submit");
        }
        synchronized (TASK_QUEUE) {
     
            if (TASK_QUEUE.size() > taskSize)
                discardPolicy.discard();
            TASK_QUEUE.addLast(runner);
            TASK_QUEUE.notifyAll();
        }
    }

    public void shutdown() throws InterruptedException {
     
        System.out.println("shutdown");
        while (!TASK_QUEUE.isEmpty()) {
     
            Thread.sleep(10);
        }
        int size = THREAD_QUEUE.size();
        while (size > 0) {
     
            for (WorkerThread task : THREAD_QUEUE) {
     
                if (task.TASK_STATE == TaskState.BLOCK) {
     
                    task.interrupt();
                    task.close();
                    size--;
                } else {
     
                    Thread.sleep(10);
                }
            }
        }
        this.isDestory = true;
        System.out.println("The Thread Pool shutdown...");
    }

    public int getThreadSize() {
     
        return threadSize;
    }

    public int getTaskSize() {
     
        return taskSize;
    }

    private enum TaskState {
     FREE, RUNNING, BLOCK, DEAD}

    private static class DiscardException extends RuntimeException {
     
        public DiscardException(String message) {
     
            super(message);
        }
    }

    private static interface DiscardPolicy {
     
        public void discard() throws DiscardException;
    }


    private static class WorkerThread extends Thread {
     
        private volatile TaskState TASK_STATE = TaskState.FREE;

        public WorkerThread(ThreadGroup threadGroup, String threadName) {
     
            super(threadGroup, threadName);
        }

        @Override
        public void run() {
     
            OUTER:
            while (TASK_STATE != TaskState.DEAD) {
     
                Runnable runner = null;
                synchronized (TASK_QUEUE) {
     
                    while (TASK_QUEUE.size() == 0) {
     
                        try {
     
                            TASK_STATE = TaskState.BLOCK;
                            TASK_QUEUE.wait();
                        } catch (InterruptedException e) {
     
                            //e.printStackTrace();
                            break OUTER;
                        }
                    }
                    runner = TASK_QUEUE.removeFirst();
                }
                if (runner != null) {
     
                    TASK_STATE = TaskState.RUNNING;
                    runner.run();
                    TASK_STATE = TaskState.FREE;
                }
            }
        }

        public void close() {
     
            TASK_STATE = TaskState.DEAD;
        }
    }

    public static void main(String[] args) throws InterruptedException {
     
        SimpleThreadPool simpleThreadPool = new SimpleThreadPool();
        for (int i = 0; i < 40; i++) {
     
            final int j = i;
            try {
     
                simpleThreadPool.submit(() -> {
     
                    System.out.println("The runnable " + j + "be served as " + Thread.currentThread().getName() + " start");
                    try {
     
                        Thread.sleep(1000);
                        System.out.println("The runnable " + j + "be served as " + Thread.currentThread().getName() + " end");
                    } catch (InterruptedException e) {
     
                        e.printStackTrace();
                    }
                });
            } catch (Exception e) {
     
//                e.printStackTrace();
                System.out.println(e);
            }
        }
        Thread.sleep(9000);
        simpleThreadPool.shutdown();
        try {
     
            simpleThreadPool.submit(() -> {
     
                System.out.println("尝试再次提交...");
            });
        } catch (Exception e) {
     
            e.printStackTrace();
        }
    }

}

三、完善版本

1.类图

我没有找到好的UML类图设计工具,然后就把代码写完之后,用IDEA生成的。
《Java动手撸源码》手写实现线程池_第2张图片
如图所示,一共14个类和接口,基本实现了线程池的功能。

2.重点代码分析

2.1 ThreadPool接口分析

《Java动手撸源码》手写实现线程池_第3张图片
代码如下:

package com.thread.threadpool;

public interface ThreadPool {
     

    // 提交任务到线程池
    void execute(Runnable runnable);

    // 关闭线程池
    void shutdown();

    // 获取线程池的初始化大小
    int getInitSize();

    // 获取线程池最大的线程数
    int getMaxSize();

    // 获取线程池核心线程数量
    int getCoreSize();

    // 获取线程池中用于缓存任务队列的大小
    int getQueueSize();

    // 获取线程池活跃的线程数量
    int getActiveCount();

    // 查看线程池是否已经被shutdown
    boolean isShutdown();
}

ThreadPool 接口就是定义了一系列的规范,比如提交任务到线程池,关闭线程池,获取线程池的初始大小、最大支持的线程数、线程池的核心线程数量、线程池缓存任务队列的大小、线程池中活跃的线程数量等。

2.2 RunableQueue接口分析

RunableQueue是任务的缓存队列,任务是做缓存,有任务来的时候进入队列,FIFO先进先出。所以要提供进入队列和弹出队列的方法。
《Java动手撸源码》手写实现线程池_第4张图片

代码如下:

package com.thread.threadpool;

public interface RunableQueue {
     
    // 当有新的任务进来时首先会offer到队列
    void offer(Runnable runnable);

    // 工作线程通过take方法获取Runnable。线程获取过程中可能会抛出异常。
    Runnable take() throws InterruptedException;

    // 获取任务队列中任务的数量
    int size();
}

2.3 拒绝策略

《Java动手撸源码》手写实现线程池_第5张图片
这里不贴代码了,因为很简单,DenyPolicy是一个函数式接口,定义了拒绝策略的接口函数,下面三个是实现类,AbortDenyPolicy的拒绝策略是抛出RunntimeException;DiscardDenyPolicy的策略是直接丢弃当前的任务,并且不做任何处理;RunnerDenyPolicy的拒绝策略是让任务提交者在自己所在的线程中执行任务。

2.4 BasicThreadPool(重点)

《Java动手撸源码》手写实现线程池_第6张图片
代码如下:

package com.thread.threadpool;

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class BasicThreadPool extends Thread implements ThreadPool {
     

    // 初始化线程数量
    private final int initSize;

    // 线程池的最大线程数量
    private final int maxSize;

    // 线程池核心线程数量
    private final int coreSize;

    // 当前活跃的线程数量
    private int activeCount;

    // 创建线程所需的工厂
    private final ThreadFactory threadFactory;

    // 任务队列
    private final RunableQueue runableQueue;

    // 线程池是否已经被shutdown
    private volatile boolean isShutdown = false;

    // 工作线程队列
    private final Queue<ThreadTask> threadTaskQueue = new ArrayDeque<>();
    // 默认的拒绝策略是丢弃的策略
    private final static DenyPolicy DEFAULT_DENY_POLICY = new DiscardDenyPolicy();

    // 默认的线程工厂实现
    private final static ThreadFactory DEFAULT_THREAD_FACTORY = new DefaultThreadFactory();

    // 默认的存活时间
    private final long keepAliveTime;

    private final TimeUnit timeUnit;

    private static class ThreadTask {
     
        Thread thread;
        WorkThread workThread;

        public ThreadTask(Thread thread, WorkThread workThread) {
     
            this.thread = thread;
            this.workThread = workThread;
        }
    }

    private static class DefaultThreadFactory implements ThreadFactory {
     
        private static final AtomicInteger GROUP_COUNT = new AtomicInteger(1);
        private static final ThreadGroup group = new ThreadGroup("MyThreadPool-" + GROUP_COUNT.getAndDecrement());
        private static AtomicInteger THREAD_COUNTER = new AtomicInteger(0);

        @Override
        public Thread createThread(Runnable runnable) {
     
            return new Thread(group, runnable, "thread-pool-" + THREAD_COUNTER.getAndIncrement());
        }
    }

    public BasicThreadPool(int initSize, int maxSize, int coreSize, ThreadFactory threadFactory,
                           int queueSize, DenyPolicy denyPolicy, long keepAliveTime, TimeUnit timeUnit) {
     
        this.initSize = initSize;
        this.maxSize = maxSize;
        this.coreSize = coreSize;
        this.activeCount = activeCount;
        this.threadFactory = threadFactory;
        this.runableQueue = new LinkedRunableQueue(queueSize, denyPolicy, this);
        this.keepAliveTime = keepAliveTime;
        this.timeUnit = timeUnit;
        init();
    }

    public BasicThreadPool(int initSize, int maxSize, int coreSize, int queueSize) {
     
        this(initSize, maxSize, coreSize, DEFAULT_THREAD_FACTORY, queueSize, DEFAULT_DENY_POLICY, 10, TimeUnit.SECONDS);
    }

    private void newThread() {
     
        WorkThread workThread = new WorkThread(runableQueue);
        Thread thread = this.threadFactory.createThread(workThread);
        ThreadTask threadTask = new ThreadTask(thread, workThread);
        threadTaskQueue.offer(threadTask);
        this.activeCount++;
        thread.start();
    }

    void init() {
     
        start();
        for (int i = 0; i < initSize; i++) {
     
            newThread();
        }
    }


    @Override
    public void execute(Runnable runnable) {
     
        if (this.isShutdown)
            throw new IllegalStateException("The ThreadPool is destory");
        this.runableQueue.offer(runnable);
    }

    @Override
    public void shutdown() {
     
        synchronized (this) {
     
            if (isShutdown) return;
            isShutdown = true;
            threadTaskQueue.forEach(threadTask -> {
     
                threadTask.workThread.stop();
                threadTask.thread.interrupt();
            });
        }
    }

    // 从线程池中移除某个线程
    private void removeThread() {
     
        ThreadTask threadTask = threadTaskQueue.remove();
        threadTask.workThread.stop();
        this.activeCount--;
    }

    @Override
    public void run() {
     
        while (!isShutdown && !interrupted()) {
     
            try {
     
                timeUnit.sleep(keepAliveTime);
            } catch (InterruptedException e) {
     
                isShutdown = true;
                break;
            }
            synchronized (this) {
     
                if (isShutdown) {
     
                    break;
                }
                // 第一次扩容:当前队列中有任务尚未处理,并且activeCount < coreSize
                if (runableQueue.size() > 0 && activeCount < coreSize) {
     
                    // 因为是首次扩容,所以起点就是初试大小
                    for (int i = initSize; i < coreSize; i++) {
     
                        newThread();
                    }
                    continue;//先扩容到coreSize大小
                }
                //第二次扩容:当前的队列中有任务尚未处理,并且activeCount < maxSize则继续扩容
                if (runableQueue.size() > 0 && activeCount < maxSize) {
     
                    // 扩容到coreSize之后,发现队列中还有任务没有得到处理,则继续扩容到maxSize。
                    for (int i = coreSize; i < maxSize; i++) {
     
                        newThread();
                    }
                }
                // 扩容结束:如果任务队列中没有任务,则需要回收部分线程,如果线程当前正在执行着任务,就等任务执行完之后回收。
                if (runableQueue.size() == 0 && activeCount > coreSize) {
     
                    removeThread();
                }
            }
        }
    }

    @Override
    public int getInitSize() {
     
        if (this.isShutdown)
            throw new IllegalStateException("The ThreadPool is destory");
        return initSize;
    }

    @Override
    public int getMaxSize() {
     
        if (this.isShutdown)
            throw new IllegalStateException("The ThreadPool is destory");
        return maxSize;
    }

    @Override
    public int getCoreSize() {
     
        if (this.isShutdown)
            throw new IllegalStateException("The ThreadPool is destory");
        return coreSize;
    }

    @Override
    public int getQueueSize() {
     
        if (this.isShutdown)
            throw new IllegalStateException("The ThreadPool is destory");
        return runableQueue.size();
    }

    @Override
    public int getActiveCount() {
     
        if (this.isShutdown)
            throw new IllegalStateException("The ThreadPool is destory");
        return activeCount;
    }

    @Override
    public boolean isShutdown() {
     
        return this.isShutdown;
    }

}



这个实现类是最复杂的,也是最关键的代码,线程池的实现原理就是:线程池维护了一个缓存队列,这个队列用来存放用户提交的任务,线程池动态的从队列里面获取任务去执行,并且根据任务的数量动态的改变线程池中执行线程的大小。所以基于如上的说明,线程池其实本身也是一个Thread线程,他的执行单元里面的逻辑是动态改变线程池大小的关键。具体大家直接去我的github仓库下载代码,用IDEA打开看一下更直观。链接: 仓库地址。大家方便的话可以给我一个star,您的鼓励是我继续下去的动力,加油。

2.5 ThreadPoolTest类代码测试

ThreadPoolTest类启动了20个任务,并通过打印,可以直观的查看线程池的变化情况。
代码如下:

package com.thread.threadpool;

import java.util.concurrent.TimeUnit;

// 线程池的测试
public class ThreadPoolTest {
     
    public static void main(String[] args) throws InterruptedException {
     
        //定义线程池,初始线程数为2,核心线程数为4,最大线程数为6,任务最多允许1000个任务。
        final ThreadPool threadPool = new BasicThreadPool(2, 6, 4, 1000);
        for (int i = 0; i < 20; i++) {
     
            threadPool.execute(() -> {
     
                try {
     
                    System.out.println(Thread.currentThread().getName() + "is running");
                    TimeUnit.SECONDS.sleep(10);
                    System.out.println(Thread.currentThread().getName() + "is done");
                } catch (InterruptedException e) {
     
                    e.printStackTrace();
                }

            });
        }
        for (; ; ) {
     
            // 不断输出线程池的信息
            System.out.println("getActiveCount:" + threadPool.getActiveCount());
            System.out.println("getQueueSize:" + threadPool.getQueueSize());
            System.out.println("getCore:" + threadPool.getCoreSize());
            System.out.println("getMaxSize:" + threadPool.getMaxSize());
            System.out.println("------------------------------------------------");
            TimeUnit.SECONDS.sleep(5);
        }
    }
}

控制台打印如下,可以看到线程池一开始启动了两个线程进行任务处理,后来经过第一次扩容到coreSize(4)个,第二次扩容到maxSize(6)个,之后线程执行的差不多之后,将线程池的大小回收到了coreSize(4)个。

com.thread.threadpool.ThreadPoolTest
getActiveCount:2
thread-pool-0is running
thread-pool-1is running
getQueueSize:18
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:2
getQueueSize:18
getCore:4
getMaxSize:6
------------------------------------------------
thread-pool-2is running
thread-pool-3is running
thread-pool-0is done
thread-pool-1is done
thread-pool-1is running
thread-pool-0is running
getActiveCount:4
getQueueSize:14
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:14
getCore:4
getMaxSize:6
------------------------------------------------
thread-pool-2is done
thread-pool-2is running
thread-pool-3is done
thread-pool-3is running
thread-pool-4is running
thread-pool-5is running
getActiveCount:6
thread-pool-0is done
thread-pool-0is running
thread-pool-1is done
getQueueSize:9
getCore:4
getMaxSize:6
------------------------------------------------
thread-pool-1is running
getActiveCount:6
getQueueSize:8
getCore:4
getMaxSize:6
------------------------------------------------
thread-pool-2is done
thread-pool-3is done
thread-pool-3is running
thread-pool-2is running
thread-pool-4is done
thread-pool-5is done
thread-pool-5is running
thread-pool-4is running
thread-pool-0is done
thread-pool-1is done
thread-pool-0is running
thread-pool-1is running
getActiveCount:6
getQueueSize:2
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:6
getQueueSize:2
getCore:4
getMaxSize:6
------------------------------------------------
thread-pool-3is done
thread-pool-2is done
thread-pool-3is running
thread-pool-2is running
thread-pool-4is done
thread-pool-5is done
thread-pool-0is done
thread-pool-1is done
getActiveCount:6
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:6
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
thread-pool-3is done
thread-pool-2is done
getActiveCount:5
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:5
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------
getActiveCount:4
getQueueSize:0
getCore:4
getMaxSize:6
------------------------------------------------

Process finished with exit code -1


总结

线程池,还是多线程领域一个非常重要的技术,很值得大家去学习。

你可能感兴趣的:(线程池)