springboot之线程池

什么时候在项目中使用线程池


  • 项目中有的任务耗时长,造成项目一直阻塞运行中。
  • 多线程任务之间不需要协同的关系。
  • 将耗时的任务异步执行,保证主线程的任务可以一直执行下去。
  • 譬如我们再导入大量数据,批量插入的时候,可以采用线程池方式。

线程池基础知识


https://www.jianshu.com/p/f5e2c8b6ed75

springboot中使用自定义线程池


  1. 配置线程池 appliction.yml
thread :
  param:
    corePoolSize: 5
    maxPoolSize: 5
    keepAliveSeconds: 10
    queueCapacity: 500
    allowCoreThreadTimeOut: false
  1. 编写线程池配置类
@EnableAsync
@Data
@Configuration
public class ThreadConfig {

    @Autowired
    private ThreadParamModel threadParam;

    @Bean("asyncExecutor")
    public Executor asyncExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        // 配置核心线程数量
        executor.setCorePoolSize(threadParam.getCorePoolSize());
        // 配置最大线程数
        executor.setMaxPoolSize(threadParam.getMaxPoolSize());
        // 配置队列容量
        executor.setQueueCapacity(threadParam.getQueueCapacity());
        // 配置空闲线程存活时间
        executor.setKeepAliveSeconds(threadParam.getKeepAliveSeconds());
        // 是否允许核心线程超时
        executor.setAllowCoreThreadTimeOut(threadParam.isAllowCoreThreadTimeOut());
        // 设置拒绝策略,直接在execute方法的调用线程中运行被拒绝的任务
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        // 执行初始化
        executor.initialize();
        return executor;
    }
}
  1. 业务方法中使用线程池
    service层:
@Async(value = "asyncExecutor")
@Override
public void insertUserForBatch(List userBOList) throws CommonException {
    log.info("==================" + Thread.currentThread().getName());
    log.info("================== 开始时间:" + TimeUtil.getLocalDateTime());
    Long datacenterId = machine.getDatacenterId();
    Long workerId = machine.getWorkerId();
    String createBy = "machine-" +
            datacenterId +
            "-" +
            workerId;
    SnowFlakeIdUtil idUtil = new SnowFlakeIdUtil(workerId, datacenterId);

    List userList = CopyUtil.copyList(userBOList, User.class);

    userList.forEach(user -> {
        user.setUserId(idUtil.genNextId());
        user.setCreateTime(TimeUtil.getLocalDateTime());
        user.setUpdateTime(TimeUtil.getLocalDateTime());
        user.setCreateBy(createBy);
    });

    userMapper.insertForBatch(userList);

    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        e.printStackTrace();
    }

    log.info("================== 结束时间:" + TimeUtil.getLocalDateTime());
}

controller层:

@PostMapping("/insert/batch")
public ResultModel insertBatch() throws CommonException {
    List userBOList = new ArrayList<>();

    for (int i = 200 ; i < 300 ; i++) {
        UserBO userBO = new UserBO();
        userBO.setUserName("张三" + i);
        userBO.setPassword("123456");
        userBOList.add(userBO);
    }

    userService.insertUserForBatch(userBOList);
    ResultModel resultModel = new ResultModel<>();
    resultModel.setTime(TimeUtil.getNowTime());
    return resultModel;
}

测试:

日志:

当发生阻塞的时候主线程直接返回了,没有造成项目阻塞等待。

老方法:使用jdk线程池类实现


封装工具类:

@Slf4j
@Component
public class ThreadService {

    @Autowired
    private ThreadParamModel threadParamModel;

    // 线程池执行器
    private volatile ThreadPoolExecutor executor;

    // 私有化构造子,阻止外部直接实例化对象
    private ThreadService() {}

    /**
     * 获取单例的线程池对象--单例的双重校验
     *
     * @return 线程池
     */
    public ThreadPoolExecutor getThreadPool() {
        if (executor == null) {
            synchronized (ThreadService.class) {
                if (executor == null) {
                    // 获取处理器数量
                    //int cpuNum = Runtime.getRuntime().availableProcessors();
                    // 根据cpu数量,计算出合理的线程并发数
                    //int maximumPoolSize = cpuNum * 2 + 1;

                    executor = new ThreadPoolExecutor(
                            // 核心线程数
                            threadParamModel.getCorePoolSize(),
                            // 最大线程数
                            threadParamModel.getMaxPoolSize(),
                            // 活跃时间
                            threadParamModel.getKeepAliveSeconds(),
                            // 活跃时间单位
                            TimeUnit.SECONDS,
                            // 线程队列
                            new LinkedBlockingDeque<>(threadParamModel.getQueueCapacity()),
                            // 线程工厂
                            Executors.defaultThreadFactory(),
                            // 队列已满,而且当前线程数已经超过最大线程数时的异常处理策略(这里可以自定义拒绝策略)
                            new ThreadPoolExecutor.AbortPolicy() {
                                @Override
                                public void rejectedExecution(Runnable r, ThreadPoolExecutor e) {
                                    log.warn("线程等待队列已满,当前运行线程总数:{},活动线程数:{},等待运行任务数:{}",
                                            e.getPoolSize(),
                                            e.getActiveCount(),
                                            e.getQueue().size());
                                }
                            }
                    );

                    executor.allowCoreThreadTimeOut(false);
                }
            }
        }
        return executor;
    }

    /**
     * 向线程池提交一个任务,返回线程结果
     *
     * @param callable 任务
     * @return 处理结果
     */
    public  Future submit(Callable callable) {
        return getThreadPool().submit(callable);
    }

    /**
     * 向线程池提交一个任务,不关心处理结果
     *
     * @param runnable 任务
     */
    public void execute(Runnable runnable) {
        if (runnable == null) throw new NullPointerException();
        getThreadPool().execute(runnable);
    }

    /**
     * 获取当前线程池线程数量
     */
    public int getSize() {
        return getThreadPool().getPoolSize();
    }

    /**
     * 获取当前活动的线程数量
     */
    public int getActiveCount() {
        return getThreadPool().getActiveCount();
    }

    /**
     * 从线程队列中移除对象
     */
    public void cancel(Runnable runnable) {
        if (executor != null) {
            getThreadPool().getQueue().remove(runnable);
        }
    }
}

业务方法调用:

@Autowired
private ThreadService threadService;

@Override
public void insertUserForBatch(List userBOList) throws CommonException {
    log.info("==================" + Thread.currentThread().getName());
    log.info("================== 开始时间:" + TimeUtil.getLocalDateTime());
    Long datacenterId = machine.getDatacenterId();
    Long workerId = machine.getWorkerId();
    String createBy = "machine-" +
            datacenterId +
            "-" +
            workerId;
    SnowFlakeIdUtil idUtil = new SnowFlakeIdUtil(workerId, datacenterId);

    List userList = CopyUtil.copyList(userBOList, User.class);

    userList.forEach(user -> {
        user.setUserId(idUtil.genNextId());
        user.setCreateTime(TimeUtil.getLocalDateTime());
        user.setUpdateTime(TimeUtil.getLocalDateTime());
        user.setCreateBy(createBy);
    });

    userMapper.insertForBatch(userList);

    threadService.execute(() -> {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        log.info("================== 结束时间:" + TimeUtil.getLocalDateTime());
    });

}

测试:

获取异步线程池的返回值


使用service层Future获取返回值:

@Async(value = "asyncExecutor")
@Override
public Future executeAsync() {
    log.info("==================" + Thread.currentThread().getName());
    LocalDateTime start = TimeUtil.getLocalDateTime();
    log.info("================== 开始时间:" + start);

    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        e.printStackTrace();
    }

    LocalDateTime end = TimeUtil.getLocalDateTime();
    long time = ChronoUnit.SECONDS.between(start, end);
    log.info("================== 结束时间:" + end);

    return new AsyncResult<>(String.valueOf(time));
}

controller层:

@PostMapping("/r/r2")
public ResultModel api2() throws CommonException, ExecutionException, InterruptedException {
    ResultModel resultModel = new ResultModel<>();
    Future future = userService.executeAsync();
    resultModel.setTime(TimeUtil.getNowTime());
    resultModel.setMsg("执行成功");
    resultModel.setRes(future.get());
    return resultModel;
}

测试:

异步多线程组合


service层:

@Async(value = "asyncExecutor")
@Override
public CompletableFuture executeCompletableAsync() {
    LocalDateTime start = TimeUtil.getLocalDateTime();
    log.info("================== 开始时间:" + start);

    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        e.printStackTrace();
    }

    LocalDateTime end = TimeUtil.getLocalDateTime();
    long time = ChronoUnit.SECONDS.between(start, end);
    log.info("================== 结束时间:" + end);

    return CompletableFuture.completedFuture(String.valueOf(time));
}

controller层:

@PostMapping("/r/r2")
public ResultModel api2() throws CommonException, ExecutionException, InterruptedException {
    ResultModel resultModel = new ResultModel<>();
    CompletableFuture future1 = userService.executeCompletableAsync();
    CompletableFuture future2 = userService.executeCompletableAsync();
    CompletableFuture future3 = userService.executeCompletableAsync();
    CompletableFuture.allOf(future1, future2, future3);

    resultModel.setTime(TimeUtil.getNowTime());
    resultModel.setMsg("执行成功");
    int res = Integer.parseInt(future1.get()) + Integer.parseInt(future2.get()) + Integer.parseInt(future3.get());
    resultModel.setRes(String.valueOf(res));
    return resultModel;
}

测试:

多线程访问同一个对象


还是搬出我多年前学多线程的小例子,卖火车票。我假定我有50张票,然后我找了200个人来抢票。

@Slf4j
@Service
public class AsyncService {

    @Autowired
    private ThreadService threadService;

    private Integer ai = 50;

    @Async(value = "asyncExecutor")
    public void executeAsync() {

        if (ai > 0) {
            try {
                Thread.sleep(2000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }

            log.info(Thread.currentThread().getName() + "正在出售第" + ai + "张票");
            ai--;
        }
    }
}

结果:

查看日志文件发现,票卖了54次,而且日志也可以看到有些票被重复售卖,已经超卖了。

解决多线程访问同一个对象问题,我们可以采用锁来实现:

方式一:利用cas方式实现的原子类

@Slf4j
@Service
public class AsyncService {

    @Autowired
    private ThreadService threadService;

    private AtomicInteger ai = new AtomicInteger(50);

    @Async(value = "asyncExecutor")
    public void executeAsync() {

        if (ai.get() > 0) {
            try {
                Thread.sleep(2000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }

            log.info(Thread.currentThread().getName() + "正在出售第" + (ai.decrementAndGet() + 1) + "张票");
        }
    }
}

方式二:使用synchronized类给代码块上锁

@Slf4j
@Service
public class AsyncService {

    @Autowired
    private ThreadService threadService;

    //private AtomicInteger ai = new AtomicInteger(50);

    private Integer ai = 50;

    @Async(value = "asyncExecutor")
    public void executeAsync() {

        synchronized (this) {
            if (ai > 0) {
                try {
                    Thread.sleep(2000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }

                log.info(Thread.currentThread().getName() + "正在出售第" + ai + "张票");
                ai--;
            }
        }
    }
}

方式三:

@Slf4j
@Service
public class AsyncService {

    @Autowired
    private ThreadService threadService;

    //private AtomicInteger ai = new AtomicInteger(50);

    private Integer ai = 50;

    private ReentrantLock rl = new ReentrantLock();

    @Async(value = "asyncExecutor")
    public void executeAsync() {
        rl.lock();
        if (ai > 0) {
            try {
                Thread.sleep(2000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }

            log.info(Thread.currentThread().getName() + "正在出售第" + ai + "张票");
            ai--;
        }
        rl.unlock();
    }
}

思考


  • 如果面对的是读多写少的情况,那么我们可以采用ReentrantReadWriteLock锁。
  • 如果多台机器之间,需要保证操作一致,可以使用分布式锁。

线程池场景设计

  • 排队延时线程池
  • 线程池实现生产者消费者模型
  • 多线程延时线程池
  • 线程池处理网络请求

利用线程池实现批量下载任务

下载单个任务的工具类:

@Slf4j
public class DownloadUtil {

    public static void download(String fileUrl, String path) throws Exception {
        HttpURLConnection conn = null;
        InputStream in = null;
        FileOutputStream out = null;
        try {
            URL url = new URL(fileUrl);
            conn = (HttpURLConnection) url.openConnection();
            conn.setRequestMethod("GET");
            conn.setConnectTimeout(10000);
            in = conn.getInputStream();
            out = new FileOutputStream(path);
            int len;
            byte[] arr = new byte[1024 * 1000];
            while (-1 != (len = in.read(arr))) {
                out.write(arr, 0, len);
            }
            out.flush();
        } catch (Exception e) {
            log.error("Fail to download: {} by {}", fileUrl, e.getMessage());
            throw new Exception(e.getMessage());
        } finally {
            try {
                if(null != conn){
                    conn.disconnect();
                }
                if (null != out) {
                    out.close();
                }
                if (null != in) {
                    in.close();
                }
            } catch (Exception e) {
                log.error("Error to close stream: {}", e.getMessage());
                throw new Exception(e.getMessage());
            }
        }
    }
}

线程配置模型:

@ToString
@Data
@Component
@ConfigurationProperties(prefix = "thread.param")
public class ThreadParamModel {

    /**
     * 核心线程数量,默认1
     */
    private int corePoolSize = 3;

    /**
     * 最大线程数量,默认Integer.MAX_VALUE;
     */
    private int maxPoolSize = 5;

    /**
     * 空闲线程存活时间
     */
    private int keepAliveSeconds = 60;

    /**
     * 线程阻塞队列容量,默认Integer.MAX_VALUE
     */
    private int queueCapacity = 1;

    /**
     * 是否允许核心线程超时
     */
    private boolean allowCoreThreadTimeOut = false;
}

线程服务类:

@Slf4j
@Component
public class ThreadService {


    @Autowired
    private ThreadParamModel threadParamModel;

    // 线程池执行器
    private volatile ThreadPoolExecutor executor;

    // 私有化构造子,阻止外部直接实例化对象
    private ThreadService() {}

    /**
     * 获取单例的线程池对象--单例的双重校验
     *
     * @return 线程池
     */
    public ThreadPoolExecutor getThreadPool() {
        if (executor == null) {
            synchronized (ThreadService.class) {
                if (executor == null) {
                    // 获取处理器数量
                    //int cpuNum = Runtime.getRuntime().availableProcessors();
                    // 根据cpu数量,计算出合理的线程并发数
                    //int maximumPoolSize = cpuNum * 2 + 1;

                    executor = new ThreadPoolExecutor(
                            // 核心线程数
                            threadParamModel.getCorePoolSize(),
                            // 最大线程数
                            threadParamModel.getMaxPoolSize(),
                            // 活跃时间
                            threadParamModel.getKeepAliveSeconds(),
                            // 活跃时间单位
                            TimeUnit.SECONDS,
                            // 线程队列
                            new LinkedBlockingDeque<>(threadParamModel.getQueueCapacity()),
                            // 线程工厂
                            Executors.defaultThreadFactory(),
                            // 队列已满,而且当前线程数已经超过最大线程数时的异常处理策略(这里可以自定义拒绝策略)
                            new ThreadPoolExecutor.AbortPolicy() {
                                @Override
                                public void rejectedExecution(Runnable r, ThreadPoolExecutor e) {
                                    log.warn("线程等待队列已满,当前运行线程总数:{},活动线程数:{},等待运行任务数:{}",
                                            e.getPoolSize(),
                                            e.getActiveCount(),
                                            e.getQueue().size());
                                }
                            }
                    );

                    executor.allowCoreThreadTimeOut(false);
                }
            }
        }
        return executor;
    }

    /**
     * 向线程池提交一个任务,返回线程结果
     *
     * @param callable 任务
     * @return 处理结果
     */
    public  Future submit(Callable callable) {
        return getThreadPool().submit(callable);
    }

    /**
     * 向线程池提交一个任务,不关心处理结果
     *
     * @param runnable 任务
     */
    public void execute(Runnable runnable) {
        if (runnable == null) throw new NullPointerException();
        getThreadPool().execute(runnable);
    }

    /**
     * 获取当前线程池线程数量
     */
    public int getSize() {
        return getThreadPool().getPoolSize();
    }

    /**
     * 获取当前活动的线程数量
     */
    public int getActiveCount() {
        return getThreadPool().getActiveCount();
    }

    /**
     * 从线程队列中移除对象
     */
    public void cancel(Runnable runnable) {
        if (executor != null) {
            getThreadPool().getQueue().remove(runnable);
        }
    }
}

任务实现类:

@Data
public class ThreadRunner implements Runnable{

    private String fileUrl;

    private String fileNumber;

    @Override
    public void run() {
        try {
            DownloadUtil.download(fileUrl, "E:\\resource\\jedis-2.9.0-" + fileNumber+ ".jar");
        } catch (Exception e) {
            e.printStackTrace();
        }

    }
}

@Data
public class ThreadCaller implements Callable {

    private String fileUrl;

    private String fileNumber;

    @Override
    public String call() throws Exception {
        try {
            DownloadUtil.download(fileUrl, "E:\\resource\\jedis-2.9.0-" + fileNumber+ ".jar");
            return "当前已经下载成功!";
        } catch (Exception e) {
            e.printStackTrace();
        }

        return "";
    }
}

服务层下载逻辑:

@Slf4j
@Service
public class DownloadService {

    @Autowired
    private ThreadService threadService;

    /**
     * 单线程下载多个文件无需返回值
     * @param fileUrl
     */
    public void download(String fileUrl) {
        try {
            for (int i = 0 ; i < 100 ; i++) {
                DownloadUtil.download(fileUrl, "E:\\resource\\jedis-2.9.0-" + i+ ".jar");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 多线程下载多个文件无需返回值
     * @param fileUrl
     */
    public void downloadByThread(String fileUrl) {
        try {
            for (int i = 0 ; i < 100 ; i++) {
                ThreadRunner threadRunner = new ThreadRunner();
                threadRunner.setFileUrl(fileUrl);
                threadRunner.setFileNumber(String.valueOf(i));
                threadService.execute(threadRunner);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 多线程下载多个文件需要返回值
     * @param fileUrl
     */
    public void downloadByThreadForRet(String fileUrl) {
        try {
            for (int i = 0 ; i < 100 ; i++) {
                ThreadCaller threadCaller = new ThreadCaller();
                threadCaller.setFileUrl(fileUrl);
                threadCaller.setFileNumber(String.valueOf(i));
                Future future = threadService.submit(threadCaller);
                log.info("-----------fileNumber = {},  result = {}", String.valueOf(i), future.get());
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

控制层实现:

@Slf4j
@RequestMapping("/download")
@RestController
public class DownloadController {

    @Autowired
    private DownloadService downloadService;

    @GetMapping("/downloadFile")
    public void downloadFile() {
        StopWatch stopWatch = new StopWatch();
        // 开始时间
        stopWatch.start();
        String fileUrl = "https://developer.aliyun.com/mvn/view/central/#Redis/clients/jedis/2.9.0/jedis-2.9.0.jar";
        downloadService.download(fileUrl);
        // 结束时间
        stopWatch.stop();

        // 统计执行时间(毫秒)
        log.info("----/download/downloadFile---------- 执行时长:{}", stopWatch.getTotalTimeMillis() + "毫秒");
    }

    @GetMapping("/downloadByThread")
    public void downloadByThread() {
        StopWatch stopWatch = new StopWatch();
        // 开始时间
        stopWatch.start();
        String fileUrl = "https://developer.aliyun.com/mvn/view/central/#Redis/clients/jedis/2.9.0/jedis-2.9.0.jar";
        downloadService.downloadByThread(fileUrl);
        // 结束时间
        stopWatch.stop();

        // 统计执行时间(毫秒)
        log.info("----/download/downloadByThread---------- 执行时长:{}", stopWatch.getTotalTimeMillis() + "毫秒");
    }


    @GetMapping("/downloadByThreadForRet")
    public void downloadByThreadForRet() {
        StopWatch stopWatch = new StopWatch();
        // 开始时间
        stopWatch.start();
        String fileUrl = "https://developer.aliyun.com/mvn/view/central/#Redis/clients/jedis/2.9.0/jedis-2.9.0.jar";
        downloadService.downloadByThreadForRet(fileUrl);
        // 结束时间
        stopWatch.stop();

        // 统计执行时间(毫秒)
        log.info("----/download/downloadByThreadForRet---------- 执行时长:{}", stopWatch.getTotalTimeMillis() + "毫秒");
    }
}

思考

  1. 上述线程池实现批量下载任务。如果单个线程下载任务时,出现了一个大文件,那么我们如何下载呢?如果直接下载缓存到数组中,容易出现java内存溢出问题,方案:文件切片,多线程分段下载、保存、合并(后续有时间,在本篇博客继续更新代码)。
  2. 上述当前线程执行的过程没有跟踪,只是在执行完毕时,才发现。答案:重写ThreadPoolExecutor类。

你可能感兴趣的:(springboot之线程池)