C++ 学习笔记(28)C++ thread_pool

代码来自网络，仅作个人学习记录，供参考。

#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <queue>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
// thread
#include <condition_variable>
#include <future>
#include <mutex>
#include <thread>

namespace YHL {

	/// A pool of worker threads that pull jobs from one shared FIFO task queue.
	///
	/// Jobs are submitted with enqueue(), which returns a std::future for the
	/// job's result. Destroying the pool stops accepting new jobs, lets the
	/// workers finish every job that is already queued, and then joins them.
	class thread_pool {
	private:
		// One pool of threads + one task queue; each thread keeps checking the
		// queue for work it can execute.
		std::vector< std::thread > pool;
		// Pending jobs, type-erased to void() so differently-typed jobs share one queue.
		std::queue< std::function<void()> > tasks;
		// Synchronization: mtx guards `tasks` and `stop`; cv wakes idle workers.
		std::mutex mtx;
		std::condition_variable cv;
		bool stop;
	public:
		thread_pool(const size_t);
		~thread_pool();

		// Returns the worker loop that every pool thread executes.
		std::function<void()> get_thread();

		// Grows the pool by the given number of additional worker threads.
		void add_thread(const size_t);

		template<class F, class... Args>
		auto enqueue(F&& fun, Args&& ...args)
			-> std::future<typename std::result_of<F(Args...)>::type>;
	};

	/// Starts `init_size` worker threads.
	inline thread_pool::thread_pool(const size_t init_size)
			: stop(false) {
		pool.reserve(init_size);
		for (size_t i = 0; i < init_size; ++i)
			pool.emplace_back(get_thread());
	}

	/// Builds the worker loop: repeatedly wait for a job (or shutdown), pop it
	/// under the lock, and run it outside the lock.
	inline std::function<void()> thread_pool::get_thread() {
		return [this] {
			for (;;) {
				std::function<void()> cur;
				{
					std::unique_lock<std::mutex> lck(this->mtx);
					this->cv.wait(lck, [this] { return this->stop || !this->tasks.empty(); });

					// Exit only once shutdown was requested AND the queue is drained,
					// so every future handed out by enqueue() is eventually satisfied.
					if (this->stop && this->tasks.empty())
						return;

					cur = std::move(this->tasks.front());
					this->tasks.pop();
				}
				// Run the job without holding the lock, then go back for the next one.
				cur();
			}
		};
	}

	/// Adds `extend` more worker threads to the pool.
	inline void thread_pool::add_thread(const size_t extend) {
		for (size_t i = 0; i < extend; ++i)
			this->pool.emplace_back(get_thread());
	}

	/// Queues `fun(args...)` for execution on a worker thread.
	///
	/// @return a future that receives the call's result (or its exception).
	/// @throws std::runtime_error if the pool has already been stopped.
	template<class F, class... Args>
	auto thread_pool::enqueue(F&& fun, Args&& ...args)
			-> std::future<typename std::result_of<F(Args...)>::type> {
		using return_type = typename std::result_of<F(Args...)>::type;

		// shared_ptr because std::function requires a copyable callable and
		// packaged_task is move-only.
		auto packed_task = std::make_shared< std::packaged_task<return_type()> >(
				std::bind(std::forward<F>(fun), std::forward<Args>(args)...)
			);

		// Take the future before publishing the task: once it is queued a worker
		// may invoke the packaged_task concurrently with this thread.
		std::future<return_type> res = packed_task->get_future();
		{
			std::unique_lock<std::mutex> lck(this->mtx);

			if (stop)
				throw std::runtime_error("enqueue task on stopped pool\n");

			this->tasks.emplace([packed_task]() { (*packed_task)(); });
		}
		this->cv.notify_one();
		return res;
	}

	/// Requests shutdown, wakes all workers, and joins them once the queue is drained.
	inline thread_pool::~thread_pool() {
		{
			std::unique_lock<std::mutex> lck(this->mtx);
			stop = true;
		}
		this->cv.notify_all();
		for (auto &it : pool)
			it.join();
	}

}

namespace test {
	// Not referenced by any code visible in this file; kept for compatibility.
	int cnt = 0;
	// Serializes console output from concurrently running pool tasks.
	std::mutex m;

	// Demo task: sleeps two seconds, prints the id of the thread that ran it,
	// and returns the constant 1022.
	int fun() {
		// Sleep outside the lock so several workers can still run concurrently.
		std::this_thread::sleep_for(std::chrono::seconds(2));
		// Lock *before* printing so lines from different workers do not interleave;
		// taking it after the output (as before) protected nothing.
		std::lock_guard<std::mutex> lck(m);
		std::cout << "id  :  " << std::this_thread::get_id() << std::endl;
		return 1022;
	}
}

int main() {
	YHL::thread_pool pool(4);
	for(int i = 0;i < 10; ++i) {
		auto result = pool.enqueue(test::fun);
		std::cout << "answer  :  " << result.get() << std::endl;
	}

	pool.add_thread(2);
	for(int i = 0;i < 20; ++i) {
		auto result = pool.enqueue(test::fun);
		std::cout << "answer  :  " << result.get() << std::endl;
	}
	return 0;
}

 

你可能感兴趣的:(c++)