通过手撸线程池深入理解其原理(中)

==> 学习汇总(持续更新)
==> 从零搭建后端基础设施系列(一)-- 背景介绍


摘要:上篇实现了简单的无锁线程池,本篇开始实现有锁线程池。先来思考一下,为什么线程池需要锁?在没有锁的线程池中,就算是单线程提交,也可能会涉及到并发的问题,如果是多线程提交任务,这时候出错的概率基本是百分百了。

一、什么是锁?
在开始写代码之前,先简单认识一下锁的本质是什么,先来看一张图。
通过手撸线程池深入理解其原理(中)_第1张图片
中间那个管道就相当于锁,不管外面有多少个线程,这个管道只能一一通过。

二、ThreadPoolV5
线程池参数:

  • workers:工作线程
  • corePoolSize:核心线程数
  • maxPoolSize:最大线程数
  • keepTimeAlive:线程空闲存活的时间
  • TimeUnit:时间单位
  • workerQueues:任务队列
  • RUNNING:线程池是否运行
  • lock:全局锁
  • currentPoolSize:当前线程池大小

代码:

public class ThreadPoolV6 {
    //核心线程数
    private int corePoolSize;

    //最大线程数
    private int maxPoolSize;

    //允许线程的空闲时间
    private long keepTimeAlive;

    //存放工作线程的哈希表
    private HashSet<Worker> workers;

    //线程池是否关闭
    private boolean RUNNING = true;

    //任务队列
    private BlockingDeque<Runnable> workerQueues;

    //全局锁
    private ReentrantLock lock = new ReentrantLock();

    //记录线程池大小
    private int currentPoolSize = 0;

    public ThreadPoolV6(int corePoolSize, int maxPoolSize, long keepTimeAlive, TimeUnit timeUnit, BlockingDeque<Runnable> workerQueues){
        this.corePoolSize = corePoolSize;
        this.maxPoolSize = maxPoolSize;
        this.keepTimeAlive = timeUnit.toNanos(keepTimeAlive);
        this.workers = new HashSet<>(corePoolSize);
        this.workerQueues = workerQueues;
    }
    static int c = 0;
    //执行任务
    public void submit(Runnable task){
        if(RUNNING){
            /*
                1.当前线程数小于核心线程数时,创建新的工作线程处理
                2.当前线程数等于核心线程数时,加入任务队列
                3.当任务队列满时,创建新的工作线程
                4.当工作线程达到最大线程数时,拒绝提交新的任务
             */
            try {
                lock.lock();
                if(currentPoolSize < corePoolSize){
                    System.out.println("核心线程数:" + (++c));
                    addWorker(task);
                    //如果队列满了,会返回false
                } else if(workerQueues.offer(task)){

                } else if(currentPoolSize < maxPoolSize){
                    addWorker(task);
                } else {
                    throw new RuntimeException("线程池已满,拒绝提交任务");
                }
            }finally {
                lock.unlock();
            }
        }
    }

    //关闭线程池
    public void shutdown(){
        try {
            lock.lock();
            RUNNING = false;
            System.out.println("关闭前线程池大小:" + currentPoolSize);
            for (Worker worker : workers) {
                worker.thread.interrupt();
            }
        } finally {
            lock.unlock();
        }
    }

    //创建新的工作线程
    private void addWorker(Runnable task){
        Worker w = new Worker(task);
        workers.add(w);
        w.thread.start();
        ++currentPoolSize;
    }

    //工作线程类
    private class Worker implements Runnable {
        Thread thread;
        Runnable task;

        public Worker(Runnable task){
            this.task = task;
            this.thread = new Thread(this);
        }

        @Override
        public void run() {
            Runnable t = this.task;
            this.task = null;
            while (t != null || (t = getTask()) != null){
                t.run();
                t = null;
            }
            try {
                lock.lock();
                workers.remove(this);
            } finally {
                lock.unlock();
            }
            System.out.println("当前线程:" + Thread.currentThread().getName() + " 退出");
        }

        private Runnable getTask(){
            boolean timeout = false;
            for (;;){
                //如果线程池关闭 并且 工作队列为空,那么可以回收该线程
                if(!RUNNING && workerQueues.isEmpty()) return null;

                try {
                    //如果超时未拿到任务 并且 当前线程数大于核心线程数的时候,就可以回收该线程
                    boolean timed;
                    try {
                        lock.lock();
                        timed = workers.size() > corePoolSize;
                        if(timed && timeout) {
                            --currentPoolSize;
                            return null;
                        }
                    } finally {
                        lock.unlock();
                    }
                    Runnable runnable = timed ? workerQueues.poll(keepTimeAlive, TimeUnit.NANOSECONDS) : workerQueues.take();
                    if(runnable != null){
                        return runnable;
                    }
                    timeout = true;
                } catch (InterruptedException e) {
                    timeout = false;
                }
            }
        }
    }
}

测试代码:

public static void main(String[] args) throws InterruptedException {
    ThreadPoolV6 pool = new ThreadPoolV6(4, 8, 1, TimeUnit.SECONDS, new LinkedBlockingDeque<>(20));
    try {
        for (int i = 0; i < 28; i++) {
            new Thread(() -> pool.submit(() -> {
                System.out.println("当前线程:" + Thread.currentThread().getName() + " 开始");
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println("当前线程:" + Thread.currentThread().getName() + " 结束");
            })).start();
        }
        Thread.sleep(3000);
        for (int i = 0; i < 8; i++) {
            new Thread(() -> pool.submit(() -> {
                System.out.println("当前线程:" + Thread.currentThread().getName() + " 开始");
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println("当前线程:" + Thread.currentThread().getName() + " 结束");
            })).start();
        }
    } finally {
        Thread.sleep(5000);
        //关闭线程池
        pool.shutdown();
    }
}

测试结果:
通过手撸线程池深入理解其原理(中)_第2张图片
问题分析:
看测试结果,加锁后是符合预期的,非核心线程被回收,核心线程常驻内存。接下来分析还有哪些问题

  • submit里的加锁方式,有些粗鲁,相当于把整个方法都给锁了,效率会大打折扣。这样会造成任务提交的时候其实是串行提交的,效率上并无任何提高,还因为加锁解锁而损耗了性能。
  • 再来看getTask里的加锁方式,思想是非常好的,只锁住了一小段执行非常快的关键代码,那就是判断当前线程池大小是否大于核心线程数,如果大于就可以回收,否则只能调用take()方法阻塞直到有新任务来临。
  • 再来看shutdown()方法,和无锁线程池的区别是,多了线程中断,这能解决什么问题呢?首先,在getTask()方法里,take()会阻塞,如果你仅仅只是修改RUNNING的状态,是不能够关闭线程的。所以需要增加一个线程中断,让take()从阻塞中抛异常,然后我们捕获处理即可。又或者是线程正在执行任务,这时候可以根据中断状态,自行决定是立刻中断还是执行完任务再中断。
  • 最后可以看看submit的返回值是void,满足不了需要监控返回值的场景。

根据这四个问题进行改进得到V6版

三、ThreadPoolV6
线程池参数:

  • workers:工作线程
  • corePoolSize:核心线程数
  • maxPoolSize:最大线程数
  • keepTimeAlive:线程空闲存活的时间
  • TimeUnit:时间单位
  • workerQueues:任务队列
  • RUNNING:线程池是否运行
  • lock:全局锁
  • currentPoolSize:当前线程池大小

代码:

public class ThreadPoolV6 {
    //核心线程数
    private int corePoolSize;

    //最大线程数
    private int maxPoolSize;

    //允许线程的空闲时间
    private long keepTimeAlive;

    //存放工作线程的哈希表
    private HashSet<Worker> workers;

    //线程池是否关闭
    private boolean RUNNING = true;

    //任务队列
    private BlockingDeque<Runnable> workerQueues;

    //全局锁
    private ReentrantLock lock = new ReentrantLock();

    //记录线程池大小
    private AtomicInteger currentPoolSize = new AtomicInteger(0);

    public ThreadPoolV8(int corePoolSize, int maxPoolSize, long keepTimeAlive, TimeUnit timeUnit, BlockingDeque<Runnable> workerQueues){
        this.corePoolSize = corePoolSize;
        this.maxPoolSize = maxPoolSize;
        this.keepTimeAlive = timeUnit.toNanos(keepTimeAlive);
        this.workers = new HashSet<>(corePoolSize);
        this.workerQueues = workerQueues;
    }
    static int c = 0;

    public <T> Future<T> submit(Runnable task){
        RunnableFuture<T> ftask = new FutureTask<T>(task, null);
        execute(ftask);
        return ftask;
    }

    public <T> Future<T> submit(Callable<T> task){
        RunnableFuture<T> ftask = new FutureTask<T>(task);
        execute(ftask);
        return ftask;
    }

    //执行任务
    private void execute(Runnable task){
        if(RUNNING){
            /*
                1.当前线程数小于核心线程数时,创建新的工作线程处理
                2.当前线程数等于核心线程数时,加入任务队列
                3.当任务队列满时,创建新的工作线程
                4.当工作线程达到最大线程数时,拒绝提交新的任务
             */
            if(currentPoolSize.get() < corePoolSize && addWorker(task, true)){
                System.out.println("核心线程数:" + (++c));
                //如果队列满了,会返回false
            } else if(workerQueues.offer(task)){

            } else if(currentPoolSize.get() < maxPoolSize && addWorker(task, false)){
                System.out.println("非核心线程数:" + (++c));
            } else {
                throw new RuntimeException("线程池已满,拒绝提交任务");
            }
        }
    }

    //关闭线程池
    public void shutdown(){
        try {
            lock.lock();
            RUNNING = false;
            System.out.println("关闭前线程池大小:" + currentPoolSize);
            for (Worker worker : workers) {
                worker.thread.interrupt();
            }
        } finally {
            lock.unlock();
        }
    }

    //创建新的工作线程
    private boolean addWorker(Runnable task, boolean core){
        for (;;){
            int c = currentPoolSize.get();
            if((core && c < corePoolSize) || !core && c < maxPoolSize){
                if(currentPoolSize.compareAndSet(c, c + 1)){
                    break;
                }
            } else {
                return false;
            }
        }
        try {
            lock.lock();
            Worker w = new Worker(task);
            workers.add(w);
            w.thread.start();
        } finally {
            lock.unlock();
        }
        return true;
    }

    //工作线程类
    private class Worker implements Runnable {
        Thread thread;
        Runnable task;

        public Worker(Runnable task){
            this.task = task;
            this.thread = new Thread(this);
        }

        @Override
        public void run() {
            Runnable t = this.task;
            this.task = null;
            while (t != null || (t = getTask()) != null){
                t.run();
                t = null;
            }
            workers.remove(this);
            System.out.println("当前线程:" + Thread.currentThread().getName() + " 退出");
        }

        private Runnable getTask(){
            boolean timeout = false;
            for (;;){
                //如果线程池关闭 并且 工作队列为空,那么可以回收该线程
                if(!RUNNING && workerQueues.isEmpty()) return null;

                try {
                    //如果超时未拿到任务 并且 当前线程数大于核心线程数的时候,就可以回收该线程
                    boolean timed;
                    try {
                        lock.lock();
                        timed = currentPoolSize.get() > corePoolSize;
                        if(timed && timeout) {
                            currentPoolSize.decrementAndGet();
                            return null;
                        }
                    } finally {
                        lock.unlock();
                    }
                    Runnable runnable = timed ? workerQueues.poll(keepTimeAlive, TimeUnit.NANOSECONDS) : workerQueues.take();
                    if(runnable != null){
                        return runnable;
                    }
                    timeout = true;
                } catch (InterruptedException e) {
                    timeout = false;
                }
            }
        }
    }
}

测试代码1:
同上

测试结果1:
同上

测试代码2:

public static void main(String[] args) throws ExecutionException, InterruptedException {
    ThreadPoolV6 pool = new ThreadPoolV6(10, 50, 1, TimeUnit.SECONDS, new LinkedBlockingDeque<>(30));
    //ThreadPoolExecutor pool = new ThreadPoolExecutor(10, 50, 1, TimeUnit.SECONDS, new LinkedBlockingDeque<>(30));
    long b = System.currentTimeMillis();
    List<Future<Integer>> futures = new ArrayList<>();
    for (int i = 0; i <= 80; i++) {
        final int start = 1000000 * i, end = 1000000 * (i + 1);
        futures.add(pool.submit(() -> solvePrime(start, end)));
    }
    int res = 0;
    for (Future<Integer> future : futures) {
        res += future.get();
    }
    long e = System.currentTimeMillis();
    System.out.println("结果:" + res);
    System.out.println("耗时:" + (e - b) / 1000.0);
}

static int solvePrime(int start, int end){
    int c = 0;
    for (int i = start; i <= end; i++) {
        c = isPrime(i) ? c + 1 : c;
    }
    return c;
}

static boolean isPrime(int num){
    for (int i = 2; i < num; i++) {
        if(num % 2 == 0){
            return false;
        }
    }
    return true;
}

测试结果2:
V6线程池
通过手撸线程池深入理解其原理(中)_第3张图片
java自带线程池
通过手撸线程池深入理解其原理(中)_第4张图片

问题分析:
看测试2,自带线程居然比V6版效率低?其实在某些场景下,自己实现的线程池效率确实比官方的好,但是稳定性,可用性方面肯定是不如官方的,毕竟官方的实现逻辑非常严谨和经过大量测试的。我们来看问题。

  • submit增加了返回值,就是那个熟悉的Future,如果是Runable的任务,get()返回的是null,如果是Callable的任务,就可以返回指定类型的值。

  • 我们重点可以看到addWorker方法,首先返回值从void变成了boolean,这是为什么呢?设想一下,如果有10个任务同时提交,那么execute中的第一行判断是不是都通过了,当前线程池大小是0,都小于核心线程数,所以addWorker都会进去。看下面这段代码,先通过CAS将线程池大小+1后,再进行实际的工作线程创建。

    private boolean addWorker(Runnable task, boolean core){
        for (;;){
            int c = currentPoolSize.get();
            //如果是核心线程,那么当前线程池大小必须小于corePoolSize
            //如果是非核心线程,那么当前线程池大小必须小于maxPoolSize
            if((core && c < corePoolSize) || !core && c < maxPoolSize){
                //将当前线程池大小+1,之后跳出循环,进行工作线程的创建和添加
                if(currentPoolSize.compareAndSet(c, c + 1)){
                    break;
                }
            } else {
                return false;
            }
        }
        ……
        return true;
    }
    
  • 最后我们再来回顾一下execute方法中的分支判断

    //在这里会出现并发addWorker的操作,但是仅仅有corePoolSize个返回true
    if(currentPoolSize.get() < corePoolSize && addWorker(task, true)){
        System.out.println("核心线程数:" + (++c));
        //当前面的addWorker返回false的时候,就会将任务加入队列中
    } else if(workerQueues.offer(task)){
    
       //这里也可能出现并发addWorker的操作,但是仅仅有maxPoolSize - corePoolSize个返回true
    } else if(currentPoolSize.get() < maxPoolSize && addWorker(task, false)){
        System.out.println("非核心线程数:" + (++c));
    } else {
        //当队列满了,并且最大线程数已经达到了,就会执行这个策略
        throw new RuntimeException("线程池已满,拒绝提交任务");
    }
    

四、总结
线程池实现到V6版,已经将其核心思想造出来了,接下来不过是对性能的优化和功能性的扩展了。

  • 线程池最核心最核心的思想是,在并发提交任务的场景下,实现对线程的高效管理和复用。
  • 要理解线程池,先从无锁入手,理解透彻它是如何管理和复用线程去执行新任务的。
  • 要深入理解线程池,需要理解锁,并发等知识,知道为什么需要锁,哪里需要锁,以及如何锁。
  • 模仿现有的轮子的时候,不要一开始就从最难的开始,应该从最简单的开始,然后一步一步的去仿造它的核心功能。

最后,关于并发处理,还是官方的,这里只是将学到的皮毛展示了一下,下一篇将分析java自带的线程池。

你可能感兴趣的:(java)