基于C++11/14/17的线程池实现

线程池,顾名思义就是预先启动一些线程,集中管理,需要的时候直接拿来用,无需用时再创建。尤其是在Windows平台,线程是稀缺资源,线程的创建和销毁都是十分耗时的,所以利用线程池来提升并发场景下的性能,是十分有必要的。

C++11首次对并发进行了支持,这使得我们利用STL编写多线程应用程序成为了可能,不过STLthread比较简陋,并没有提供更多的强大特性,只是一个最基础的多线程解决方案,定位应该是尽可能面向更多应用场景的最通用的版本。

先上代码:Thread Pool
为了拥抱新标准,尽量使用了C++14/17的特性,所以如果需要编译运行这些代码的话,需要启用你的编译器的c++17特性。

本文不对thread,mutex,condition_variable作过多解释,不知道的同学可以去查阅相关资料 ——> cpp_reference。

首先我们定义设计ThreadPool这个类的接口:

class ThreadPool {
    public:
        explicit ThreadPool(const size_t& max_threads);

        template<class Func, typename... Args>
        decltype(auto) submitTask(Func&& func, Args&&... args);
        void pause();
        void unpause();
        void close();
        bool isClosed() const;
        ~ThreadPool();
    protected:
        void _scheduler();
        void _launchNew();
    private:
        static size_t core_thread_count;
        size_t max_thread_count;
        // thread-manager
        std::vector<std::thread> threads;
        // tasks-queue
        std::queue<std::function<void()>> tasks;
        // for synchronization
        std::mutex queue_mtx;
        std::mutex pause_mtx;
        std::condition_variable cond_var;
        bool paused;
        bool closed;
    };

我参考了部分Java中线程池的设计方式,有core_thread_count来指定最先启动的基础工作线程数,max_thread_count来指定最多运行的工作线程数,通常被传入构造函数的整数max_threads指定,如果max_threadscore_thread_count还小,那么指定max_thread_count=core_thread_count.以下是完整的构造函数:

thread_pool::ThreadPool::ThreadPool(const size_t & max_threads)
    : closed(false), paused(false), max_thread_count(max_threads)
{
    if (max_threads <= 0) {
        max_thread_count = core_thread_count;
        throw std::runtime_error("Invalid thread-number passed in.");
    }
    size_t t_count = core_thread_count;
    if (max_threads < core_thread_count) {
        max_thread_count = core_thread_count;
        t_count = max_threads;
    }
    // launch some threads firstly.
    for (size_t i = 0; i < t_count; ++i) {
        _launchNew();
    }
    // lanuch sheduler and running background.
    std::thread scheduler = std::thread(&ThreadPool::_scheduler, this);
    scheduler.detach();
}

接下来解析两个关键的函数:_scheduler_launchNew
_launchNew函数用于启动一个工作线程,具体实现如下:

void thread_pool::ThreadPool::_launchNew()
{
    if (threads.size() < max_thread_count) {
        threads.emplace_back([this] {
            while (true) {
                if (this->paused) {
                    std::unique_lock<std::mutex> pause_lock(this->pause_mtx);
                    cond_var.wait(pause_lock, [this] {
                        return !this->paused;
                    });
                }
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(this->queue_mtx);
                    cond_var.wait(lock, [this] {
                        return this->closed || !this->tasks.empty();  // trigger when close or new task comes.
                    });
                    if (this->closed)  // exit when close.
                        return;
                    task = std::move(this->tasks.front());
                    this->tasks.pop();
                }
                task();  // execute task.
            }
        }
        );
    }
}

最核心的是while(true)体内的部分。
这是一个无限循环的函数体,除非线程池被关闭,否则它会一直的运行,首先通过condition_variable::wait使其进入休眠,当队列中有新任务加入时,会有condition_variable::notify_one唤醒此线程,然后它从中取出队首的任务,此处是一个void()类型的可调用对象(经过lambda包装,后面会讲),即function,取出后执行任务,执行结束继续进入休眠状态(也可能会遇到线程池暂停,具体的可以自己看代码)。

_scheduler()是线程的调度函数,实现如下:

void thread_pool::ThreadPool::_scheduler()
{
    // find new task and notify one free thread to execute.
    while (!this->closed) {  // auto-exit when close.
        if (this->paused) {
            std::unique_lock<std::mutex> pause_lock(this->pause_mtx);
            cond_var.wait(pause_lock, [this] {
                return !this->paused;
            });
        }

        if (tasks.empty() ||
            tasks.size() > max_thread_count)  // if tasks-size > max_threads , just loop for waiting.
            continue;
        else if (tasks.size() <= threads.size())
            cond_var.notify_one();
        else if (tasks.size() < max_thread_count) {
            _launchNew();
            cond_var.notify_one();
        }
    }
}

也是一个无限循环的while结构,当队列中有任务时,根据任务数量,通知已有的线程去取任务执行,或者增加线程数量,继续通知。

最最重要的是submitTask()这个函数,它是整个线程池的核心,也是任务提交到线程池内的唯一接口,这里用到了decltype(auto),init capture,std::apply,std::make_tupleC++14/17新特性,当然还有c++11加入的universal reference等,就不一一赘述了,先上具体实现:

template<class Func, typename... Args>
    inline decltype(auto)
        ThreadPool::submitTask(Func&& func, Args&&...args)
    {
            auto task = std::async(std::launch::async,
                std::forward(func), std::forward(args)...);
            return task;
        }

        using return_type = typename std::result_of_t(Args...)>;

        auto task = std::make_shared::packaged_task()>>(
            [func = std::forward(func),
            args = std::make_tuple(std::forward(args)...)]()->return_type{
            return std::apply(func, args);
        }
        );

        auto fut = task->get_future();
        {
            std::lock_guard::mutex> lock(this->queue_mtx);
            if (this->closed || this->paused)
                throw std::runtime_error("Do not allow executing tasks after closed or paused.");
            tasks.emplace([=]() {  // `=` mode instead of `&` to avoid ref-dangle.
                (*task)();
            });
        }
        return fut;
    }
}

第一个参数是一个可调用对象,后面的变长模板参数是这个可调用对象的参数,使用通用引用配合std::forward就可以达到完美转发任意函数及其参数的作用。因为我们在执行任务的时候,调用的统一是std::function对象,所以我们要把真正的函数用一个无形参无返回值的lambda包装,而任务的执行结果用std::future可以得到,所以我们先构建std::packaged_task,然后通过它得到该任务的future对象,它用于最后的返回,我们可以通过对返回后的它调用.get()得到任务的运行结果,包装好的packaged_task则用无返回值无形参的lambda包装之后,加入队列。

其它的暂停/关闭这些小接口的实现就自己看源码吧~
然后做了个简单的性能测试(代码在example.cpp,经人提醒去掉了一些可能会影响结果的噪声):

  • 创建线程 -> 执行任务 -> 销毁线程
  • 创建线程池 -> 把任务丢到线程池里

VS-Run模式下,性能差距非常明显:
如图:
基于C++11/14/17的线程池实现_第1张图片

Windows平台下使用clang++编译运行,两者的性能差距与上图接近。

Linux下使用最新的clang++ 5.0除了需要添加-std=c++17,还需要添加-lpthread参数,但是我编译好运行的时候直接被bash给杀了…也是很无奈。

差不多就是这样,一个积极拥抱新标准的线程池实现,后续准备将std::queue用无锁队列替换,可以进一步提高性能。

觉得有帮助的可以给个Star.

你可能感兴趣的:(C++)