Java 实现 优先级线程池

本文使用 ThreadPoolExecutor 实现一个 带优先级的线程池


最近做一个 PPT 转 PDF 的功能, 调用 office 的"另存为", 耗时较长 (大约 2 秒转一个文件), 而且只能单线程来跑. 项目要求批量转好并发送邮件; 如果用户手动点击生成 PDF, 则应该尽快生成, 不能等批量转完后再让用户下载. 所以就实现了一个有优先级的线程池任务队列. 其实常规的实现方式是直接使用优先级队列 (java.util.PriorityQueue / java.util.concurrent.PriorityBlockingQueue), 但这种方式没办法同步地获取结果, 编程上有点复杂. 而 java.util.concurrent.ThreadPoolExecutor 提供了 public <T> Future<T> submit(Callable<T> task), 调用 Future.get() 会阻塞当前线程等待结果, 从而实现同步调用.

public class PriorityThreadPoolExecutor extends ThreadPoolExecutor;

实现方法很简单: 继承 ThreadPoolExecutor, 并使用 PriorityBlockingQueue 作为任务队列. 不过 PriorityBlockingQueue 有个坑, 它的 Javadoc 里写着:

Operations on this class make no guarantees about the ordering of elements with equal priority.

如果优先级相同,不能确定顺序. 

实际测试下来的结果是: 如果优先级相同, 则执行顺序跟插入顺序相反. 这就尴尬了, 这还是 FIFO 队列吗? 官方文档 (PriorityBlockingQueue 的 Javadoc) 给出了解决方式: 给每一个队列元素按插入顺序编号, 优先级相同时按编号排序, 照抄就可以了. 限制是队列历史元素的总数不能超过 Long.MAX_VALUE 个. 为此实现一个 Comparable 的类:

class PriorityRunnable implements Runnable, Comparable<PriorityRunnable>;

重载线程池添加任务的方法, 追加一个 priority 参数 (数值越大越先执行). 如果使用基类的方法, 优先级为 0.

    public void execute(Runnable command, int priority);
    public <T> Future<T> submit(Callable<T> task, int priority);
    public <T> Future<T> submit(Runnable task, T result, int priority);
    public Future<?> submit(Runnable task, int priority);

最终代码如下

package wang.lcs.sys.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A {@link ThreadPoolExecutor} backed by a {@link PriorityBlockingQueue}: queued tasks with a larger
 * priority value run first, and tasks with equal priority run in submission (FIFO) order.
 *
 * <p>The inherited {@code execute}/{@code submit} methods use priority {@code 0}; the overloads with
 * an extra {@code priority} argument let callers choose a priority.
 *
 * <p>Caveats inherited from {@link ThreadPoolExecutor}:
 * <ul>
 *   <li>A task that causes a new core thread to be started is handed to that thread directly and
 *       never waits in the queue, so priorities only order tasks that actually queue up.</li>
 *   <li>The queue is unbounded, so the pool does not grow past {@code corePoolSize} threads;
 *       {@code maximumPoolSize} has no effect.</li>
 * </ul>
 */
public class PriorityThreadPoolExecutor extends ThreadPoolExecutor {

    /**
     * Hands the priority from a {@code submit(..., priority)} overload to {@link #execute(Runnable)},
     * which the inherited {@code submit} calls after wrapping the task in a future. The default
     * value 0 is the priority used by the inherited methods.
     */
    private final ThreadLocal<Integer> local = ThreadLocal.withInitial(() -> 0);

    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue());
    }

    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory);
    }

    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), handler);
    }

    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory, handler);
    }

    /**
     * Creates the work queue for a new pool.
     *
     * @return a fresh, unbounded priority queue; its elements are the {@link PriorityRunnable}
     *         wrappers created by {@link #execute(Runnable, int)}
     */
    protected static PriorityBlockingQueue<Runnable> getWorkQueue() {
        return new PriorityBlockingQueue<>();
    }

    /**
     * Executes {@code command} with the priority stored by a {@code submit(..., priority)} call on
     * the current thread, or with priority 0 when there is none.
     */
    @Override
    public void execute(Runnable command) {
        int priority = local.get();
        // Consume the pending priority so it applies to exactly this one task.
        local.remove();
        execute(command, priority);
    }

    /**
     * Executes {@code command} with the given priority; larger values run first.
     *
     * @param command  the task to run
     * @param priority the task priority
     * @throws NullPointerException if {@code command} is null
     */
    public void execute(Runnable command, int priority) {
        if (command == null) {
            // Fail fast, as ThreadPoolExecutor.execute does, instead of queueing a wrapper that
            // would throw inside a worker thread.
            throw new NullPointerException();
        }
        super.execute(new PriorityRunnable(command, priority));
    }

    /**
     * Submits a value-returning task with the given priority.
     *
     * @return a future representing the pending result of {@code task}
     */
    public <T> Future<T> submit(Callable<T> task, int priority) {
        local.set(priority);
        try {
            return super.submit(task);
        } finally {
            // super.submit normally consumes the value via execute(Runnable); clear it anyway in
            // case it threw before getting there (e.g. a null task), so it cannot leak into a later
            // unrelated call on this thread.
            local.remove();
        }
    }

    /**
     * Submits a runnable task with the given priority.
     *
     * @return a future whose {@code get()} returns {@code result} on completion
     */
    public <T> Future<T> submit(Runnable task, T result, int priority) {
        local.set(priority);
        try {
            return super.submit(task, result);
        } finally {
            local.remove();
        }
    }

    /**
     * Submits a runnable task with the given priority.
     *
     * @return a future whose {@code get()} returns {@code null} on completion
     */
    public Future<?> submit(Runnable task, int priority) {
        local.set(priority);
        try {
            return super.submit(task);
        } finally {
            local.remove();
        }
    }

    /**
     * Queue element that orders tasks by descending priority, then by ascending submission
     * sequence number, so equal-priority tasks run in FIFO order (the approach suggested in the
     * {@link PriorityBlockingQueue} documentation).
     */
    protected static class PriorityRunnable implements Runnable, Comparable<PriorityRunnable> {
        /** Global submission counter used to break priority ties. */
        private static final AtomicLong seq = new AtomicLong();
        private final long seqNum;
        final Runnable run;
        private int priority;

        public PriorityRunnable(Runnable run, int priority) {
            seqNum = seq.getAndIncrement();
            this.run = run;
            this.priority = priority;
        }

        public int getPriority() {
            return priority;
        }

        /**
         * Changes the priority. Must not be called while this element is inside a priority queue:
         * the queue does not re-sort, so its ordering would become inconsistent.
         */
        public void setPriority(int priority) {
            this.priority = priority;
        }

        public Runnable getRun() {
            return run;
        }

        @Override
        public void run() {
            this.run.run();
        }

        /**
         * Orders by priority, highest first, then by sequence number, earliest first. Every
         * wrapper gets a distinct sequence number, so only an element compared with itself
         * yields 0.
         */
        @Override
        public int compareTo(PriorityRunnable other) {
            int byPriority = Integer.compare(other.priority, this.priority); // DESC
            if (byPriority != 0) {
                return byPriority;
            }
            return Long.compare(this.seqNum, other.seqNum); // ASC
        }
    }
}

下面是测试用例

package wang.lcs.sys.util;

import org.junit.Assert;
import org.junit.Test;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.*;

/**
 * Tests for {@link PriorityThreadPoolExecutor}. Every test uses a single core thread, so the order in
 * which tasks append to the shared buffer is the order in which they left the priority queue.
 */
public class PriorityThreadPoolExecutorTest {

    /** Tasks submitted through the inherited {@code submit} all get priority 0 and must run FIFO. */
    @Test
    public void testDefault() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        try {
            Future<?>[] futures = new Future<?>[20];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                int index = i;
                futures[i] = pool.submit(new Callable<Object>() {
                    @Override
                    public Object call() throws Exception {
                        Thread.sleep(10);
                        buffer.append(index + ", ");
                        return null;
                    }
                });
            }
            // Wait for all tasks to finish.
            for (Future<?> future : futures) {
                future.get();
            }
            System.out.println(buffer);
            assertEquals("0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ", buffer.toString());
        } finally {
            // Core threads never time out, so release them explicitly.
            pool.shutdownNow();
        }
    }

    /** Tasks with the same explicit priority must run in submission order. */
    @Test
    public void testSamePriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        try {
            Future<?>[] futures = new Future<?>[10];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                futures[i] = pool.submit(new TenSecondTask<Void>(i, 1, buffer), 1);
            }
            // Wait for all tasks to finish.
            for (Future<?> future : futures) {
                future.get();
            }
            System.out.println(buffer);
            assertEquals("01@00, 01@01, 01@02, 01@03, 01@04, 01@05, 01@06, 01@07, 01@08, 01@09, ", buffer.toString());
        } finally {
            pool.shutdownNow();
        }
    }

    /** Tasks with random priorities must (apart from the first few) run in descending priority order. */
    @Test
    public void testRandomPriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        try {
            Future<?>[] futures = new Future<?>[20];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                int r = (int) (Math.random() * 100);
                futures[i] = pool.submit(new TenSecondTask<Void>(i, r, buffer), r);
            }
            // Wait for all tasks to finish.
            for (Future<?> future : futures) {
                future.get();
            }
            System.out.println(buffer);
            // A possible result: 80@00, 99@15, 92@04, 83@16, 77@09, 73@14, 72@07, 66@10, 63@02, 58@01,
            // 56@08, 52@05, 47@03, 35@11, 33@12, 32@17, 22@06, 16@13, 11@19, 07@18
            // Apart from the first one or two entries, everything runs in priority order.
            String[] entries = buffer.toString().split(", ");
            assertEquals(futures.length, entries.length);
            // Start at 2: the first task goes straight to the new worker, and the second may be
            // dequeued before all other tasks have been submitted.
            for (int i = 2; i < entries.length - 1; i++) {
                assertTrue(priorityOf(entries[i]) >= priorityOf(entries[i + 1]));
            }
        } finally {
            pool.shutdownNow();
        }
    }

    /** Extracts the priority from an entry formatted as {@code "<priority>@<index>"}. */
    private static int priorityOf(String entry) {
        return Integer.parseInt(entry.split("@")[0]);
    }

    /**
     * Callable that sleeps briefly (10 ms, despite the class name) and then appends
     * {@code "<priority>@<index>, "} (both zero-padded to two digits) to the shared buffer.
     *
     * @param <T> nominal result type; {@link #call()} always returns {@code null}
     */
    public static class TenSecondTask<T> implements Callable<T> {
        private final StringBuffer buffer;
        final int index;
        final int priority;

        public TenSecondTask(int index, int priority, StringBuffer buffer) {
            this.index = index;
            this.priority = priority;
            this.buffer = buffer;
        }

        @Override
        public T call() throws Exception {
            Thread.sleep(10);
            buffer.append(String.format("%02d@%02d", this.priority, index)).append(", ");
            return null;
        }
    }
}

需要说明的是: 这里使用了 ThreadLocal, 把 submit 方法传入的优先级传递给 execute 方法, 从而减少了代码的复制粘贴.

你可能感兴趣的:(simple-code)