线程池代码分析及延伸应用(续二)

代码延伸应用

8. 线程池与分布式系统的结合
#include <iostream>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <queue>
#include <map>
#include <vector>
#include <atomic>
#include <future>
#include <memory>
#include <chrono>
#include <fstream>
#include <string>
#include <stdexcept>
#include <asio.hpp>

class MthreadPool {
public:
    // condition_variable 不可拷贝、不可移动,无法用 resize 扩容,
    // 因此两个 vector 直接在初始化列表中按 max 构造
    MthreadPool(int min, int max) : minthread(min), maxthread(max),
        stopthread(false), idlethread(min), currentthread(min),
        taskQueues(max), taskQueueConditions(max),
        startTime(std::chrono::steady_clock::now())
    {
        manager = new std::thread(&MthreadPool::Manager, this);
        counttask = 0;
        addcount = 0;
        totalTaskTime = 0;   // 以毫秒为单位累计任务执行时间
        for (int i = 0; i < max; i++)
        {
            std::thread t(&MthreadPool::Worker, this, i);
            workers.insert(std::make_pair(t.get_id(), std::move(t)));
        }
    }

    ~MthreadPool()
    {
        stopthread = true;
        for (auto& cond : taskQueueConditions) {
            cond.notify_all();
        }
        for (auto& it : workers)
        {
            if (it.second.joinable())
            {
                it.second.join();
            }
        }
        if (manager->joinable())
        {
            manager->join();
        }
        delete manager;
        auto endTime = std::chrono::steady_clock::now();
        activeTime = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
        logStatistics();
    }

    template<class F, class... Args>
    auto AddTask(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type> {
        using return_type = typename std::result_of<F(Args...)>::type;

        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
        );

        std::future<return_type> res = task->get_future();
        int queueIndex = 0;   // 在锁外保存队列下标,解锁后再通知对应的条件变量
        {
            std::lock_guard<std::mutex> lock(queueMutex);
            if (stopthread)
                throw std::runtime_error("enqueue on stopped ThreadPool");
            queueIndex = nextQueueIndex++ % maxthread;   // 轮询分发到各个任务队列
            taskQueues[queueIndex].emplace([task]() { (*task)(); });
            addcount++;
            totalQueueSize++;
        }
        taskQueueConditions[queueIndex].notify_one();
        return res;
    }

    void logStatistics() {
        std::ofstream logFile("thread_pool_stats.log", std::ios::app);
        if (logFile.is_open()) {
            logFile << "Active Time: " << activeTime.count() << " ms" << std::endl;
            logFile << "Total Tasks Added: " << addcount.load() << std::endl;
            logFile << "Total Tasks Completed: " << counttask.load() << std::endl;
            if (counttask.load() > 0) {
                auto avgTaskTime = totalTaskTime.load() / counttask.load();
                logFile << "Average Task Execution Time: " << avgTaskTime << " ms" << std::endl;
            }
            logFile.close();
        }
    }

private:
    int minthread;
    int maxthread;
    std::atomic<bool> stopthread;
    std::atomic<int> idlethread;
    std::atomic<int> currentthread;
    std::atomic<int> exithread{ 0 };
    std::atomic<int> counttask;
    std::atomic<int> addcount;
    std::map<std::thread::id, std::thread> workers;
    std::thread* manager;
    std::vector<std::queue<std::function<void()>>> taskQueues;
    std::mutex queueMutex;
    std::mutex exitidMutex;
    std::vector<std::condition_variable> taskQueueConditions;
    std::atomic<int> nextQueueIndex{0};
    std::vector<std::thread::id> exit_id;
    std::atomic<int> totalQueueSize{0};
    std::atomic<long long> totalTaskTime;   // 累计任务执行时间,单位:毫秒
    std::chrono::time_point<std::chrono::steady_clock> startTime;
    std::chrono::milliseconds activeTime;

    void Manager(void)
    {
        while (!stopthread.load())
        {
            std::this_thread::sleep_for(std::chrono::seconds(3));
            int idle = idlethread.load();
            int cur = currentthread.load();
            int totalSize = 0;
            {
                std::lock_guard<std::mutex> lock(queueMutex);
                for (const auto& queue : taskQueues) {
                    totalSize += queue.size();
                }
                totalQueueSize = totalSize;
            }

            if (idle > cur / 2 && cur > minthread && totalSize < cur / 2)
            {
                exithread.store(2);
                for (auto& cond : taskQueueConditions) {
                    cond.notify_all();
                }
                std::lock_guard<std::mutex> lock(exitidMutex);
                for (auto id : exit_id)
                {
                    auto thread = workers.find(id);
                    if (thread != workers.end())
                    {
                        (*thread).second.join();
                        workers.erase(thread);
                    }
                }
                exit_id.clear();
            }
            else if (idle == 0 && cur < maxthread && totalSize > cur * 2)
            {
                // std::atomic<int> 不能按值拷贝,这里显式取出当前值作为队列下标
                std::thread t(&MthreadPool::Worker, this, currentthread.load());
                workers.insert(std::make_pair(t.get_id(), std::move(t)));
                currentthread++;
                idlethread++;
            }
        }
    }

    void Worker(int queueIndex) {
        while (!stopthread.load()) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(queueMutex);
                taskQueueConditions[queueIndex].wait(lock, [this, queueIndex] {
                    return stopthread || !taskQueues[queueIndex].empty();
                });
                if (stopthread && taskQueues[queueIndex].empty())
                    return;
                if (exithread.load() > 0) {
                    currentthread--;
                    exithread--;
                    idlethread--;
                    std::lock_guard<std::mutex> exitLock(exitidMutex);
                    exit_id.emplace_back(std::this_thread::get_id());
                    return;
                }
                task = std::move(taskQueues[queueIndex].front());
                taskQueues[queueIndex].pop();
                totalQueueSize--;
            }
            idlethread--;
            auto taskStartTime = std::chrono::steady_clock::now();
            try {
                task();
                counttask++;
            } catch (const std::exception& e) {
                std::cerr << "Exception caught in worker thread: " << e.what() << std::endl;
            }
            auto taskEndTime = std::chrono::steady_clock::now();
            auto taskDuration = std::chrono::duration_cast<std::chrono::milliseconds>(taskEndTime - taskStartTime);
            totalTaskTime += taskDuration.count();   // 以毫秒累加
            idlethread++;
        }
    }
};

// 分布式任务处理类
class DistributedTaskHandler {
public:
    DistributedTaskHandler(MthreadPool& pool, short port)
        : pool(pool),
          acceptor(io_context, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), port)) {
        startAccept();
    }

    void run() {
        io_context.run();
    }

private:
    MthreadPool& pool;
    asio::io_context io_context;
    asio::ip::tcp::acceptor acceptor;

    void startAccept() {
        // 每个连接使用独立的 socket,避免多个连接复用同一个成员 socket 造成冲突
        auto socket = std::make_shared<asio::ip::tcp::socket>(io_context);
        acceptor.async_accept(*socket, [this, socket](std::error_code ec) {
            if (!ec) {
                handleRequest(socket);
            }
            startAccept();
        });
    }

    void handleRequest(std::shared_ptr<asio::ip::tcp::socket> socket) {
        // buffer 用 shared_ptr 持有,保证异步读取完成前不会被销毁
        auto buffer = std::make_shared<asio::streambuf>();
        asio::async_read_until(*socket, *buffer, '\n',
            [this, socket, buffer](std::error_code ec, std::size_t /*length*/) {
                if (!ec) {
                    std::istream is(buffer.get());
                    std::string taskData;
                    std::getline(is, taskData);
                    // 这里简单模拟任务处理,实际应用中需要根据协议解析任务数据
                    pool.AddTask([taskData]() {
                        std::cout << "Processing task: " << taskData << std::endl;
                        // 模拟任务执行
                        std::this_thread::sleep_for(std::chrono::seconds(2));
                    });
                }
                socket->close();
            });
    }
};

int main() {
    MthreadPool pool(2, 5);
    DistributedTaskHandler handler(pool, 12345);

    std::thread ioThread([&handler]() {
        handler.run();
    });

    // 主线程可以继续做其他事情
    std::this_thread::sleep_for(std::chrono::seconds(10));

    ioThread.join();
    return 0;
}

解释

  • 分布式任务处理类 DistributedTaskHandler:该类使用了 asio 库来实现网络通信,监听指定端口(这里是 12345),接收远程任务请求。
  • 异步接受连接:startAccept 方法使用 async_accept 异步接受新的连接,每个连接使用独立的 socket;当有新连接到来时,调用 handleRequest 方法处理请求。
  • 任务处理:handleRequest 方法使用 async_read_until 异步读取客户端发送的数据(以换行符结尾的一行文本),并将其作为任务数据。这里简单模拟了任务处理,实际应用中需要根据具体的协议解析任务数据,然后将任务添加到线程池 pool 中执行。对应的客户端发送示例见下方代码。
  • 多线程运行:在 main 函数中,创建一个单独的线程来运行 DistributedTaskHandler 的 run 方法,主线程可以继续执行其他任务。
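
下面是一个简化的客户端示意代码(补充示例,非原文内容):假设服务端运行在本机的 12345 端口,客户端把一行以 '\n' 结尾的文本作为任务数据发送过去,其中的任务字符串 "sample_task_001" 只是示例占位。

#include <asio.hpp>
#include <iostream>
#include <string>

int main() {
    try {
        asio::io_context io_context;
        asio::ip::tcp::socket socket(io_context);
        // 假设服务端与客户端在同一台机器上,端口为 12345
        socket.connect(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 12345));
        std::string task = "sample_task_001\n";   // 以 '\n' 结尾,与服务端 async_read_until 的分隔符约定一致
        asio::write(socket, asio::buffer(task));
        socket.close();
    } catch (const std::exception& e) {
        std::cerr << "Client error: " << e.what() << std::endl;
    }
    return 0;
}
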
9. 线程池与机器学习任务的结合

在机器学习领域,很多任务可以并行化处理,例如数据预处理、模型训练等。我们可以使用线程池来加速这些任务的执行。

#include <iostream>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <queue>
#include <map>
#include <vector>
#include <atomic>
#include <future>
#include <memory>
#include <chrono>
#include <random>
#include <string>
#include <stdexcept>

// 模拟机器学习数据
struct MLData {
    std::vector<double> features;
    double label;
};

class MthreadPool {
    // 线程池类的定义保持不变,此处省略详细代码
};

// 数据预处理函数
void preprocessData(MLData& data) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(0.1, 0.9);
    for (auto& feature : data.features) {
        feature *= dis(gen);
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

// 模型训练函数
void trainModel(const std::vector<MLData>& data) {
    std::this_thread::sleep_for(std::chrono::seconds(1));
    std::cout << "Model training completed." << std::endl;
}

int main() {
    MthreadPool pool(2, 5);
    std::vector<MLData> dataset;
    // 生成一些模拟数据
    for (int i = 0; i < 10; ++i) {
        MLData data;
        data.features.resize(10);
        for (int j = 0; j < 10; ++j) {
            data.features[j] = static_cast<double>(i + j);
        }
        data.label = static_cast<double>(i);
        dataset.push_back(data);
    }

    // 并行进行数据预处理
    std::vector<std::future<void>> futures;
    for (auto& data : dataset) {
        futures.emplace_back(pool.AddTask(preprocessData, std::ref(data)));
    }

    // 等待所有数据预处理任务完成
    for (auto& future : futures) {
        future.wait();
    }

    // 进行模型训练
    auto trainFuture = pool.AddTask(trainModel, std::ref(dataset));
    trainFuture.wait();

    return 0;
}

解释

  • 模拟机器学习数据:定义了 MLData 结构体来表示机器学习数据,包含特征向量和标签。
  • 数据预处理:preprocessData 函数模拟了数据预处理过程,对每个特征进行随机缩放,并休眠一段时间模拟处理耗时。
  • 模型训练:trainModel 函数模拟了模型训练过程,休眠一段时间表示训练耗时。
  • 并行处理:在 main 函数中,将数据预处理任务添加到线程池中并行执行,使用 std::future 来管理任务的执行结果。等待所有数据预处理任务完成后,再将模型训练任务添加到线程池中执行。当任务有返回值时,也可以按块划分数据、分别计算后再汇总,见下方示例。
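
作为补充,下面给出一个按块划分、最后汇总部分结果的小例子(非原文内容,chunkCount、chunkSize 等名称均为示例假设)。为了保持示例自包含,这里用 std::async 演示同样的"划分-并行-汇总"模式,把 std::async(...) 换成上文的 pool.AddTask(...) 即可得到相同的用法。

#include <algorithm>
#include <cstddef>
#include <future>
#include <iostream>
#include <vector>

int main() {
    // 用 0,1,2,... 填充一组模拟标签
    std::vector<double> labels(1000);
    for (std::size_t i = 0; i < labels.size(); ++i) {
        labels[i] = static_cast<double>(i);
    }

    const std::size_t chunkCount = 4;   // 划分为 4 块并行处理
    const std::size_t chunkSize = (labels.size() + chunkCount - 1) / chunkCount;

    std::vector<std::future<double>> partials;
    for (std::size_t c = 0; c < chunkCount; ++c) {
        std::size_t begin = c * chunkSize;
        std::size_t end = std::min(begin + chunkSize, labels.size());
        // 每个任务计算本块标签的部分和
        partials.emplace_back(std::async(std::launch::async, [&labels, begin, end]() {
            double sum = 0.0;
            for (std::size_t i = begin; i < end; ++i) {
                sum += labels[i];
            }
            return sum;
        }));
    }

    double total = 0.0;
    for (auto& f : partials) {
        total += f.get();   // 汇总各块的结果
    }
    std::cout << "Sum of labels: " << total << std::endl;   // 预期输出 499500
    return 0;
}
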
10. 线程池的安全性增强
#include <iostream>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <queue>
#include <map>
#include <vector>
#include <atomic>
#include <future>
#include <memory>
#include <chrono>
#include <string>
#include <stdexcept>

class MthreadPool {
public:
    // condition_variable 不可拷贝、不可移动,无法用 resize 扩容,
    // 因此两个 vector 直接在初始化列表中按 max 构造(max 为负时先构造空容器,随后抛出异常)
    MthreadPool(int min, int max) : minthread(min), maxthread(max),
        stopthread(false), idlethread(min), currentthread(min),
        taskQueues(max > 0 ? max : 0), taskQueueConditions(max > 0 ? max : 0)
    {
        if (min < 0 || max < min) {
            throw std::invalid_argument("Invalid thread pool size parameters");
        }
        manager = new std::thread(&MthreadPool::Manager, this);
        counttask = 0;
        addcount = 0;
        for (int i = 0; i < max; i++)
        {
            std::thread t(&MthreadPool::Worker, this, i);
            workers.insert(std::make_pair(t.get_id(), std::move(t)));
        }
    }

    ~MthreadPool()
    {
        stopthread = true;
        for (auto& cond : taskQueueConditions) {
            cond.notify_all();
        }
        for (auto& it : workers)
        {
            if (it.second.joinable())
            {
                it.second.join();
            }
        }
        if (manager->joinable())
        {
            manager->join();
        }
        delete manager;
    }

    template<class F, class... Args>
    auto AddTask(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type> {
        using return_type = typename std::result_of<F(Args...)>::type;

        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
        );

        std::future<return_type> res = task->get_future();
        int queueIndex = 0;   // 在锁外保存队列下标,解锁后再通知对应的条件变量
        {
            std::lock_guard<std::mutex> lock(queueMutex);
            if (stopthread)
                throw std::runtime_error("enqueue on stopped ThreadPool");
            if (totalQueueSize.load() >= maxQueueSize) {
                throw std::runtime_error("Task queue is full");
            }
            queueIndex = nextQueueIndex++ % maxthread;   // 轮询分发到各个任务队列
            taskQueues[queueIndex].emplace([task]() { (*task)(); });
            addcount++;
            totalQueueSize++;
        }
        taskQueueConditions[queueIndex].notify_one();
        return res;
    }

private:
    const int maxQueueSize = 100;  // 任务队列最大长度
    int minthread;
    int maxthread;
    std::atomic<bool> stopthread;
    std::atomic<int> idlethread;
    std::atomic<int> currentthread;
    std::atomic<int> exithread{ 0 };
    std::atomic<int> counttask;
    std::atomic<int> addcount;
    std::map<std::thread::id, std::thread> workers;
    std::thread* manager;
    std::vector<std::queue<std::function<void()>>> taskQueues;
    std::mutex queueMutex;
    std::mutex exitidMutex;
    std::vector<std::condition_variable> taskQueueConditions;
    std::atomic<int> nextQueueIndex{0};
    std::vector<std::thread::id> exit_id;
    std::atomic<int> totalQueueSize{0};

    void Manager(void)
    {
        while (!stopthread.load())
        {
            std::this_thread::sleep_for(std::chrono::seconds(3));
            int idle = idlethread.load();
            int cur = currentthread.load();
            int totalSize = 0;
            {
                std::lock_guard<std::mutex> lock(queueMutex);
                for (const auto& queue : taskQueues) {
                    totalSize += queue.size();
                }
                totalQueueSize = totalSize;
            }

            if (idle > cur / 2 && cur > minthread && totalSize < cur / 2)
            {
                exithread.store(2);
                for (auto& cond : taskQueueConditions) {
                    cond.notify_all();
                }
                std::lock_guard<std::mutex> lock(exitidMutex);
                for (auto id : exit_id)
                {
                    auto thread = workers.find(id);
                    if (thread != workers.end())
                    {
                        (*thread).second.join();
                        workers.erase(thread);
                    }
                }
                exit_id.clear();
            }
            else if (idle == 0 && cur < maxthread && totalSize > cur * 2)
            {
                // std::atomic<int> 不能按值拷贝,这里显式取出当前值作为队列下标
                std::thread t(&MthreadPool::Worker, this, currentthread.load());
                workers.insert(std::make_pair(t.get_id(), std::move(t)));
                currentthread++;
                idlethread++;
            }
        }
    }

    void Worker(int queueIndex) {
        while (!stopthread.load()) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(queueMutex);
                taskQueueConditions[queueIndex].wait(lock, [this, queueIndex] {
                    return stopthread || !taskQueues[queueIndex].empty();
                });
                if (stopthread && taskQueues[queueIndex].empty())
                    return;
                if (exithread.load() > 0) {
                    currentthread--;
                    exithread--;
                    idlethread--;
                    std::lock_guard<std::mutex> exitLock(exitidMutex);
                    exit_id.emplace_back(std::this_thread::get_id());
                    return;
                }
                task = std::move(taskQueues[queueIndex].front());
                taskQueues[queueIndex].pop();
                totalQueueSize--;
            }
            idlethread--;
            try {
                task();
                counttask++;
            } catch (const std::exception& e) {
                std::cerr << "Exception caught in worker thread: " << e.what() << std::endl;
                // 可以添加更多异常处理逻辑,如重试任务
            }
            idlethread++;
        }
    }
};

// 示例使用
void sampleTask() {
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
    std::cout << "Task completed." << std::endl;
}

int main() {
    try {
        MthreadPool pool(2, 5);
        for (int i = 0; i < 20; ++i) {
            pool.AddTask(sampleTask);
        }
        std::this_thread::sleep_for(std::chrono::seconds(10));
    } catch (const std::exception& e) {
        std::cerr << "Exception in main: " << e.what() << std::endl;
    }
    return 0;
}

解释

  • 构造函数参数检查:在构造函数中,添加了对最小线程数和最大线程数的检查。如果最小线程数小于 0 或者最大线程数小于最小线程数,会抛出 std::invalid_argument 异常,避免创建无效的线程池。
  • 任务队列边界检查:在 AddTask 方法中,增加了对任务队列总长度的检查。如果任务队列的总长度达到了预设的最大长度 maxQueueSize,会抛出 std::runtime_error 异常,防止任务队列无限增长,导致内存溢出。
  • 异常处理完善:在 Worker 函数的异常处理部分,可以进一步添加重试逻辑,例如对于一些可重试的异常,可以尝试重新执行任务,提高任务执行的成功率,下面给出一个简单的重试包装思路作为参考。
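
下面是一个假设性的重试包装器草图(非原文内容,withRetry、maxRetries 等均为示例命名):把可能失败的任务包装成"最多重试若干次"的新任务,再提交给线程池即可,线程池本身不需要修改;Worker 中捕获到的将是重试次数用尽后的最终异常。

#include <chrono>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <thread>

// 把任务包装成带重试的任务:失败后最多再重试 maxRetries 次,每次失败后退避 100ms
std::function<void()> withRetry(std::function<void()> task, int maxRetries) {
    return [task, maxRetries]() {
        for (int attempt = 0; attempt <= maxRetries; ++attempt) {
            try {
                task();
                return;   // 执行成功,直接返回
            } catch (const std::exception& e) {
                std::cerr << "Attempt " << (attempt + 1) << " failed: " << e.what() << std::endl;
                if (attempt == maxRetries) {
                    throw;   // 重试次数用尽,把异常继续向外抛
                }
                std::this_thread::sleep_for(std::chrono::milliseconds(100));
            }
        }
    };
}

int main() {
    int counter = 0;
    // 一个前两次失败、第三次成功的模拟任务
    auto flaky = [&counter]() {
        if (++counter < 3) {
            throw std::runtime_error("transient failure");
        }
        std::cout << "Task succeeded on attempt " << counter << std::endl;
    };
    // 配合上文线程池使用时,可以写成 pool.AddTask(withRetry(flaky, 5));
    withRetry(flaky, 5)();
    return 0;
}
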
11. 线程池的跨平台兼容性优化

虽然 C++ 标准库提供了跨平台的线程和同步机制,但在不同的操作系统上,线程的行为和性能可能会有所差异。为了提高线程池的跨平台兼容性和性能,可以根据不同的操作系统进行一些优化。

#include <iostream>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <queue>
#include <map>
#include <vector>
#include <atomic>
#include <future>
#include <memory>
#include <chrono>
#include <string>
#include <stdexcept>

#ifdef _WIN32
#include <windows.h>
#else
#include <pthread.h>
#include <sched.h>
#endif

class MthreadPool {
public:
    // condition_variable 不可拷贝、不可移动,无法用 resize 扩容,
    // 因此两个 vector 直接在初始化列表中按 max 构造(max 为负时先构造空容器,随后抛出异常)
    MthreadPool(int min, int max) : minthread(min), maxthread(max),
        stopthread(false), idlethread(min), currentthread(min),
        taskQueues(max > 0 ? max : 0), taskQueueConditions(max > 0 ? max : 0)
    {
        if (min < 0 || max < min) {
            throw std::invalid_argument("Invalid thread pool size parameters");
        }
        manager = new std::thread(&MthreadPool::Manager, this);
        counttask = 0;
        addcount = 0;
        for (int i = 0; i < max; i++)
        {
            std::thread t(&MthreadPool::Worker, this, i);
            setThreadPriority(t);
            workers.insert(std::make_pair(t.get_id(), std::move(t)));
        }
    }

    ~MthreadPool()
    {
        stopthread = true;
        for (auto& cond : taskQueueConditions) {
            cond.notify_all();
        }
        for (auto& it : workers)
        {
            if (it.second.joinable())
            {
                it.second.join();
            }
        }
        if (manager->joinable())
        {
            manager->join();
        }
        delete manager;
    }

    template<class F, class... Args>
    auto AddTask(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type> {
        using return_type = typename std::result_of<F(Args...)>::type;

        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
        );

        std::future<return_type> res = task->get_future();
        int queueIndex = 0;   // 在锁外保存队列下标,解锁后再通知对应的条件变量
        {
            std::lock_guard<std::mutex> lock(queueMutex);
            if (stopthread)
                throw std::runtime_error("enqueue on stopped ThreadPool");
            if (totalQueueSize.load() >= maxQueueSize) {
                throw std::runtime_error("Task queue is full");
            }
            queueIndex = nextQueueIndex++ % maxthread;   // 轮询分发到各个任务队列
            taskQueues[queueIndex].emplace([task]() { (*task)(); });
            addcount++;
            totalQueueSize++;
        }
        taskQueueConditions[queueIndex].notify_one();
        return res;
    }

private:
    const int maxQueueSize = 100;  // 任务队列最大长度
    int minthread;
    int maxthread;
    std::atomic<bool> stopthread;
    std::atomic<int> idlethread;
    std::atomic<int> currentthread;
    std::atomic<int> exithread{ 0 };
    std::atomic<int> counttask;
    std::atomic<int> addcount;
    std::map<std::thread::id, std::thread> workers;
    std::thread* manager;
    std::vector<std::queue<std::function<void()>>> taskQueues;
    std::mutex queueMutex;
    std::mutex exitidMutex;
    std::vector<std::condition_variable> taskQueueConditions;
    std::atomic<int> nextQueueIndex{0};
    std::vector<std::thread::id> exit_id;
    std::atomic<int> totalQueueSize{0};

    void Manager(void)
    {
        while (!stopthread.load())
        {
            std::this_thread::sleep_for(std::chrono::seconds(3));
            int idle = idlethread.load();
            int cur = currentthread.load();
            int totalSize = 0;
            {
                std::lock_guard<std::mutex> lock(queueMutex);
                for (const auto& queue : taskQueues) {
                    totalSize += queue.size();
                }
                totalQueueSize = totalSize;
            }

            if (idle > cur / 2 && cur > minthread && totalSize < cur / 2)
            {
                exithread.store(2);
                for (auto& cond : taskQueueConditions) {
                    cond.notify_all();
                }
                std::lock_guard<std::mutex> lock(exitidMutex);
                for (auto id : exit_id)
                {
                    auto thread = workers.find(id);
                    if (thread != workers.end())
                    {
                        (*thread).second.join();
                        workers.erase(thread);
                    }
                }
                exit_id.clear();
            }
            else if (idle == 0 && cur < maxthread && totalSize > cur * 2)
            {
                // std::atomic<int> 不能按值拷贝,这里显式取出当前值作为队列下标
                std::thread t(&MthreadPool::Worker, this, currentthread.load());
                setThreadPriority(t);
                workers.insert(std::make_pair(t.get_id(), std::move(t)));
                currentthread++;
                idlethread++;
            }
        }
    }

    void Worker(int queueIndex) {
        while (!stopthread.load()) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(queueMutex);
                taskQueueConditions[queueIndex].wait(lock, [this, queueIndex] {
                    return stopthread || !taskQueues[queueIndex].empty();
                });
                if (stopthread && taskQueues[queueIndex].empty())
                    return;
                if (exithread.load() > 0) {
                    currentthread--;
                    exithread--;
                    idlethread--;
                    std::lock_guard<std::mutex> exitLock(exitidMutex);
                    exit_id.emplace_back(std::this_thread::get_id());
                    return;
                }
                task = std::move(taskQueues[queueIndex].front());
                taskQueues[queueIndex].pop();
                totalQueueSize--;
            }
            idlethread--;
            try {
                task();
                counttask++;
            } catch (const std::exception& e) {
                std::cerr << "Exception caught in worker thread: " << e.what() << std::endl;
                // 可以添加更多异常处理逻辑,如重试任务
            }
            idlethread++;
        }
    }

    void setThreadPriority(std::thread& t) {
#ifdef _WIN32
        HANDLE handle = t.native_handle();
        SetThreadPriority(handle, THREAD_PRIORITY_NORMAL);
#else
        pthread_t nativeThread = t.native_handle();
        sched_param param;
        param.sched_priority = sched_get_priority_min(SCHED_OTHER);
        pthread_setschedparam(nativeThread, SCHED_OTHER, &param);
#endif
    }
};

// 示例使用
void sampleTask() {
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
    std::cout << "Task completed." << std::endl;
}

int main() {
    try {
        MthreadPool pool(2, 5);
        for (int i = 0; i < 20; ++i) {
            pool.AddTask(sampleTask);
        }
        std::this_thread::sleep_for(std::chrono::seconds(10));
    } catch (const std::exception& e) {
        std::cerr << "Exception in main: " << e.what() << std::endl;
    }
    return 0;
}

解释

  • 条件编译:使用 #ifdef _WIN32 进行条件编译,根据不同的操作系统选择不同的线程优先级设置方法。在 Windows 系统上,使用 SetThreadPriority 函数设置线程优先级;在其他系统上,使用 pthread_setschedparam 函数设置线程优先级。
  • 线程优先级设置:在创建每个工作线程时,调用 setThreadPriority 函数设置线程的优先级。这样可以根据不同的操作系统和任务需求,合理调整线程的优先级,提高线程池的性能。除了优先级之外,线程池的线程数量也可以结合硬件并发数来选择,见下面的小示例。
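
下面是一个补充性的小示例(非原文内容,minThreads、maxThreads 为示例命名):用标准库的 std::thread::hardware_concurrency() 查询硬件并发数,作为跨平台选择线程池大小的参考,避免在不同平台上硬编码线程数量。

#include <algorithm>
#include <iostream>
#include <thread>

int main() {
    // hardware_concurrency() 在无法获取时可能返回 0,这里回退到 2
    unsigned int hw = std::thread::hardware_concurrency();
    int minThreads = 2;
    int maxThreads = std::max(minThreads, hw == 0 ? 2 : static_cast<int>(hw));
    std::cout << "Suggested pool size: [" << minThreads << ", " << maxThreads << "]" << std::endl;
    // 实际使用时:MthreadPool pool(minThreads, maxThreads);
    return 0;
}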
