手写能任务窃取的线程池

目录

function_wrapper.hpp:

stealing_queue.hpp

thread_pool_steal_hpp

参考:《C++并发编程实战》

对于thread_pool_steal.hpp的代码有改动,不然运行不了

function_wrapper.hpp:


//包装可调用对象,对外消除对象型别,还需要有一个函数调用符
//
//
//私有成员:
//      1.一个虚基类(没有的话在3很难定义一个指向实例化类的指针)
//      struct impl_base{
//              1.一个纯虚函数用来让派生类继承3执行函数
//              2.虚析构函数
//      }
//      2.因为需要接受任意类型的可调用对象,所以内部需要封装一个模板类(派生自1>)
//      template
//      struct impl_type{
//              1.成员变量:传入的函数
//              2.构造函数:移动构造成员变量
//              3.成员函数:执行函数
//      }
//      3.一个指向实例化后的指针(在这里看出确实需要一个基类,可以通过基类指针构
造派生类指针)
//公有成员:
//      1.构造函数,因为接收的对象是任意类型,所以是模板构造函数,并且使用万能引
用和完美转发实例化私有成员的3.
//      2.()运算符的重载通过调用私有成员2.的3.实现
//      3.默认构造函数 = default
//      4.移动拷贝构造函数
//      5.移动复制构造函数
//   
 

#ifndef _FUNCTION_WRAPPER_HPP_
#define _FUNCTION_WRAPPER_HPP_

#include
class function_wrapper
{
	struct impl_base{
		virtual void call() = 0;
		virtual ~impl_base(){};
	};	
	template
	struct impl_type:impl_base{
		Function f;
		impl_type(Function&& f_):f(std::move(f_)){}
		void call(){ f(); }
	};
	std::unique_ptr impl;
public:
	function_wrapper() = default;
	template
	function_wrapper(Function&& f):impl(new impl_type(std::move(f)))	{}
	void operator()(){
		impl->call();
	}
	function_wrapper(function_wrapper&& other):impl(std::move(other.impl)) {}
	function_wrapper& operator=(function_wrapper&& other){
		impl = std::move(other.impl);
		return *this;
	}

	function_wrapper(function_wrapper& other) = delete;
	function_wrapper(const function_wrapper& other) = delete;
	function_wrapper& operator=(const function_wrapper& other)=delete;
	
};
#endif

stealing_queue.hpp


//可以进行任务窃取的队列
//
//私有成员:
//      1.一个双端队列,pop操作在队头进行,steal操作在队尾进行,存储的内容为function_wrapper类
//      2.互斥 保证安全,因为不存储线程,所以不需要条件变量传递线程运行信息
//公有成员:
//      1.默认构造函数
//      2.push
//      3.try_pop
//      4.try_steal基本和pop都一样,就是弹出队尾元素//在线程池中定义一个存储窃取
队列的vector,vector的索引代表每个线程的标识,这样每个线程就可以通过这个vector访>问窃取队列。
//      5.empty
 

#ifndef _STEALING_QUEUE_HPP
#define _STEALING_QUEUE_HPP


#include"function_wrapper.hpp"
#include
#include
class work_stealing_queue
{
	typedef function_wrapper data_type;
	std::deque the_queue;
	mutable std::mutex mut;
public:
	work_stealing_queue(){}
	work_stealing_queue(const work_stealing_queue& othre) = delete;
	work_stealing_queue& operator=(const work_stealing_queue& othre) = delete;
	
	void push(data_type data){
		std::lock_guard lk(mut);
		the_queue.push_front(std::move(data));
	}
	bool try_pop(data_type& data){
		std::lock_guard lk(mut);
		if(the_queue.empty())
			return false;
		else{
			data = std::move(the_queue.front());
			the_queue.pop_front();
			return true;
		}
	}
	bool try_steal(data_type& data){
		std::lock_guard lk(mut);
		if(the_queue.empty())
                        return false;
                else{
                        data = std::move(the_queue.front());
                        the_queue.pop_back();
                        return true;
                }

	}
	bool empty() const{
		std::lock_guard lk(mut);
		return the_queue.empty();
	}
};


#endif

thread_pool_steal_hpp:

//私有成员:
//      1.控制线程正常运行的原子变量,抛出异常设置为true
//      2.线程池的池队列,基于普通的线程安全队列
//      3.提供索引的队列存储指向窃取队列的指针
//      4.存放线程的队列
//      5.封装可联结线程
//      6.静态本地线程变量 本地线程队列
//      7.静态本地线程变量 索引
//
//      8.work_thread工作函数(任务函数)
//      9.判断能否从 本地线程队列获取任务
//      10.判断能否从 线程池队列获取任务
//      11.判断能否从 其他线程窃取任务
//公有成员:
//      1.构造函数:初始化原子变量,可联结线程类。
//      {
//              try
//              {       
//                      for()
//                      {初始化提供索引的队列   }
//                      for()
//                      {初始化线程}
//              }
//              catch
//      }
//      2.提交任务函数//为了获取返回值应该返回future,传入各种函数,所以应该是模
板函数{
//              1.传入函数打包给pakage_task
//              2.获取future
//              3.判断是传入本地线程队列还是线程池队列
//              4.return future;
//      }
//      3.run_package_task//work_thread的主要部分
//      {
//              if(判断从哪个队列获得任务)
//              {
//                      task()
//              }
//              else 交出cpu时间。
//        }

//        4.析构函数
 

#ifndef _THREAD_POOL_STEAL_HPP_
#define _THREAD_POOL_STEAL_HPP_

#include "threadsafe_queue.hpp"
#include "ThreadRAII.h"
#include "stealing_queue.hpp"
#include "function_wrapper.hpp"
#include "stealing_queue.hpp"

#include
#include
#include
#include
#include
class thread_pool_steal
{
	typedef function_wrapper task_type;
	std::atomic_bool done;
	threadsafe_queue pool_work_queue;
	std::vector> queues;
	std::vector threads;
       	join_threads joiner;
	static thread_local work_stealing_queue* local_work_queue;
	static thread_local unsigned my_index ;
	void work_thread(unsigned index){
		my_index = index;
		local_work_queue = queues[my_index].get();
		while(!done){
			run_pending_task();
		}	
	}
	bool pop_from_local(task_type& task){
		return local_work_queue && local_work_queue->try_pop(task);
	}
	bool pop_from_pool(task_type& task){
		return pool_work_queue.try_pop(task);
	}
	bool pop_from_steal(task_type& task){
		work_stealing_queue steal_queue;
		for(unsigned i = 0; itry_steal(task))
			{
				return true;
			}
		}
		return false;
	}
public:
	thread_pool_steal():done(false),joiner(threads)
	{
		unsigned number = std::thread::hardware_concurrency();
		try{
			for(unsigned i = 0; i(new work_stealing_queue));
			}
			for(unsigned i = 0; i
	std::future::type> submit(Function f)
	{
		typedef typename std::result_of::type result_type;
		std::packaged_task task(f);
		std::future res(task.get_future());
		int index = 0;
		for(auto& ptr:queues){
			if(ptr->empty()){
				ptr->push(std::move(task));
				index = -1;
				break;
			}
			index++;
		}
		if(index>=0){
			pool_work_queue.push(std::move(task));
		}
		return res;
	}
	void run_pending_task(){
		task_type task;
		if(pop_from_local(task) || pop_from_pool(task) || pop_from_steal(task))
		{
			task();
		}
		else{
			std::this_thread::yield();
		}
	}

};
thread_local work_stealing_queue* thread_pool_steal::local_work_queue = nullptr;
thread_local unsigned thread_pool_steal::my_index = 0;
#endif

你可能感兴趣的:(c++,算法,开发语言)