【C++】实现线程池

整理自《【C++】《C++ 并发编程实战 (第2版) 》笔记-Chapter9-高级线程管理》

1、线程池类,包括 AJoinThreads、AQueueThreadSafe、ATaskType、AStealTaskQueue、AMyThreadPool。其中,

AJoinThreads:RAII思想,析构时调用 join() 自动等待所有线程完成。

AQueueThreadSafe:基于 std::mutex 的线程安全的 queue,用于存放任务。

ATaskType:任务包装类。

AStealTaskQueue:基于 std::mutex 的线程安全的 deque,用于存放各个线程自己的任务,是一个可被其他线程窃取任务的队列。

AMyThreadPool:对外使用的主类,外面通过 submit() 函数提交任务给线程池并得到对应任务的 future。

#ifndef AMYTHREADPOOL_H
#define AMYTHREADPOOL_H

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

// RAII思想,析构时调用 join() 自动等待所有线程完成。
class AJoinThreads
{
public:
    explicit AJoinThreads(std::vector &threads)
        : m_threads(threads)
    {}

    ~AJoinThreads()
    {
        for(std::thread& t : m_threads)
        {
            if(t.joinable())
                t.join();
        }
    }

private:
    std::vector &m_threads;
};

// 基于 std::mutex 的线程安全的 queue,用于存放任务。
template
class AQueueThreadSafe
{
public:
    AQueueThreadSafe()
    {}

    AQueueThreadSafe(const AQueueThreadSafe &other)
    {
        std::lock_guard lk(other.m_mtx);
        m_queueData = other.m_queueData;
    }

    void push(T newValue)
    {
        std::lock_guard lk(m_mtx);
        m_queueData.push(std::move(newValue));
        m_cond.notify_one();
    }

    void wait_and_pop(T &value)
    {
        std::unique_lock lk(m_mtx);
        m_cond.wait(lk, [this]{ return false == m_queueData.empty(); });
        value = m_queueData.front();
        m_queueData.pop();
    }

    std::shared_ptr wait_and_pop()
    {
        std::unique_lock lk(m_mtx);
        m_cond.wait(lk, [this]{ return false == m_queueData.empty(); });
        std::shared_ptr res(std::make_shared(m_queueData.front()));
        m_queueData.pop();
        return res;
    }

    bool try_pop(T &value)
    {
        std::lock_guard lk(m_mtx);
        if(m_queueData.empty())
            return false;
        value = std::move(m_queueData.front());
        m_queueData.pop();
        return true;
    }

    std::shared_ptr try_pop()
    {
        std::lock_guard lk(m_mtx);
        if(m_queueData.empty())
            return std::shared_ptr();
        std::shared_ptr res(std::make_shared(m_queueData.front()));
        m_queueData.pop();
        return res;
    }

    bool empty() const
    {
        std::lock_guard lk(m_mtx);
        return m_queueData.empty();
    }

private:
    mutable std::mutex m_mtx;     // 互斥必须用 mutable 修饰(针对 const 对象,准许其数据成员发生变动)
    std::queue m_queueData;
    std::condition_variable m_cond;
};

// 任务包装类。
// 由于 std::packaged_task<> 的实例仅能移动而不可复制,但 std::function<> 要求本身所含的函数对象能进行拷贝构造,因此任务队列的元素不能用 std::function<> 充当。
// 我们必须定制自己的类作为代替,以包装函数,并处理只移型别。
// 这个类其实就是一个包装可调用对象的简单类,对外可消除该对象的型别。这个类还具备函数调用操作符。
// 实现基类 ATaskTypeImplBase 是有必要的,不然 m_impl 就必须为 std::unique_ptr>,这就需要类 ATaskType 是个模板类,这违背了“对外可消除该对象的型别”的初衷。
class ATaskType
{
    struct ATaskTypeImplBase
    {
        virtual ~ATaskTypeImplBase() {}
        virtual void call() {}
    };

    template
    struct ATaskTypeImpl : public ATaskTypeImplBase
    {
        F f;
        ATaskTypeImpl(F&& f_)
            : f(std::move(f_))
        {}

        void call()
        {
            f();
        }
    };

public:
    template
    ATaskType(F&& f)
        : m_impl(new ATaskTypeImpl(std::move(f)))
    {}

    void operator() ()
    {
        m_impl->call();
    }

    ATaskType() = default;
    ATaskType(ATaskType&& other)
        : m_impl(std::move(other.m_impl))
    {}

    ATaskType& operator=(ATaskType&& other)
    {
        m_impl = std::move(other.m_impl);
        return *this;
    }

private:
    ATaskType(const ATaskType&) = delete;
    ATaskType(ATaskType&) = delete;
    ATaskType& operator=(const ATaskType&) = delete;

private:
    std::unique_ptr m_impl;
};

// 基于 std::mutex 的线程安全的 deque,用于存放各个线程自己的任务,是一个可被其他线程窃取任务的队列。
class AStealTaskQueue
{
public:
    AStealTaskQueue()
    {}
    AStealTaskQueue(const AStealTaskQueue& other) = delete;
    AStealTaskQueue& operator=(const AStealTaskQueue& other) = delete;

    void push(ATaskType data)
    {
        std::lock_guard lk(m_mtx);
        m_dequeTask.push_front(std::move(data));
    }

    bool empty() const
    {
        std::lock_guard lk(m_mtx);
        return m_dequeTask.empty();
    }

    bool try_pop(ATaskType &res)
    {
        std::lock_guard lk(m_mtx);
        if(m_dequeTask.empty())
        {
            return false;
        }
        res = std::move(m_dequeTask.front());
        m_dequeTask.pop_front();
        return true;
    }

    bool try_steal(ATaskType &res)      // 窃取时,操作的是队列后端
    {
        std::lock_guard lk(m_mtx);
        if(m_dequeTask.empty())
            return false;
        res = std::move(m_dequeTask.back());
        m_dequeTask.pop_back();
        return true;
    }

private:
    std::deque m_dequeTask;
    mutable std::mutex m_mtx;
};

// 对外使用的主类,外面通过 submit() 函数提交任务给线程池并得到对应任务的 future。
class AMyThreadPool
{
public:
    AMyThreadPool()
        : m_bDone(false)
        , m_joiner(m_threads)
    {
        unsigned const nThreadCnt = std::thread::hardware_concurrency();
        try
        {
            // 构造每个线程自己的可被别的线程窃取的任务队列
            for(unsigned i=0; i(new AStealTaskQueue));
            }

            // 构造线程池中每个线程,并向线程函数传递参数 i 作为线程下标
            for(unsigned i=0; i
    std::future::type> submit(FunctionType f)
    {
        typedef typename std::result_of::type result_type;
        std::packaged_task task(f);
        std::future res(task.get_future());
        if(tl_pLocalTaskQueue)
        {
            tl_pLocalTaskQueue->push(std::move(task));
        }
        else
        {
            m_queueTask.push(std::move(task));
        }
        return res;
    }

protected:
    void workerThread(unsigned nCurThreadIndex)
    {
        tl_nCurThreadIndex = nCurThreadIndex;
        tl_pLocalTaskQueue = m_vecTaskQueue[nCurThreadIndex].get();
        while(false == m_bDone)
        {
            runTask();
        }
    }

    void runTask()
    {
        ATaskType task;
        if(getTaskFromLocalThreadQueue(task)
                || getTaskFromTotalQueue(task)
                || getTaskFromOtherThreadQueue(task))
        {
            task();
        }
        else
        {
            std::this_thread::yield();
        }
    }

    bool getTaskFromLocalThreadQueue(ATaskType &task)
    {
        return tl_pLocalTaskQueue && tl_pLocalTaskQueue->try_pop(task);
    }

    bool getTaskFromTotalQueue(ATaskType &task)
    {
        return m_queueTask.try_pop(task);
    }

    bool getTaskFromOtherThreadQueue(ATaskType &task)
    {
        for(unsigned i=0; itry_steal(task))
                return true;
        }
        return false;
    }

private:
    std::vector m_threads;                             // 用 vector 存放所有线程
    std::atomic_bool m_bDone;                                       // 停止线程工作的标记
    AJoinThreads m_joiner;                                          // 析构时等待所有线程完成
    AQueueThreadSafe m_queueTask;                        // 总的任务队列
    std::vector > m_vecTaskQueue;  // 每个线程自己的任务队列

    static thread_local AStealTaskQueue* tl_pLocalTaskQueue;        // 当前线程的任务队列
    static thread_local unsigned tl_nCurThreadIndex;                // 当前线程的下标
};
thread_local AStealTaskQueue* AMyThreadPool::tl_pLocalTaskQueue;
thread_local unsigned AMyThreadPool::tl_nCurThreadIndex;

#endif // AMYTHREADPOOL_H

2、测试代码

#include "threadpool/amythreadpool.h"
#include 
#include 
#include 

namespace test3
{
    void testMyThreadPool(int nHandleWay)
    {
        AMyThreadPool threadPool;
        auto start = std::chrono::steady_clock::now();

        for(int i=0; i<10; ++i)
        {
            auto f = threadPool.submit([=]()->int
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(100));
                return i;
            });
            if(1 == nHandleWay)         // costMs:2009
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(200));
                std::cout << f.get() << std::endl;
            }
            else if(2 == nHandleWay)    // costMs:1009
                f.get();
            else if(3 == nHandleWay)    // costMs:1009
                std::cout << f.get() << std::endl;
            else if(4 == nHandleWay)    // costMs:0
            {}
        }

        int costMs = std::chrono::duration_cast(std::chrono::steady_clock::now()- start).count();
        std::cout << std::thread::hardware_concurrency() << " costMs:" << costMs << std::endl;
    }
}

int main(int argc, char *argv[])
{
    std::cout << "Test Begin...\n";

    // -----------------------------------------------------------------

    test3::testMyThreadPool(1);
    std::cout << "----------------------------\n";
    test3::testMyThreadPool(2);
    std::cout << "----------------------------\n";
    test3::testMyThreadPool(3);
    std::cout << "----------------------------\n";
    test3::testMyThreadPool(4);

    // -----------------------------------------------------------------

    std::cout << "Test End...\n";

    return 0;
}

你可能感兴趣的:(【C++】实战,c++)