// 使用 Linux 系统函数和 C++03 类写的一个简单线程池
// (A simple thread pool built with Linux pthread functions and C++03 classes.)

#include <pthread.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <vector>

using namespace std;

// One queued job: a worker thread calls fun(arg).
struct task_t
{
	// Job entry point; same void* (*)(void*) shape as a pthread start routine.
	void*	(*fun)(void*);
	// Opaque argument handed to fun unchanged; the pool never dereferences it.
	void*	arg;
};

// Reports a failed pthread call on stderr as "<err>:<description>\n".
// err: name of the call that failed (e.g. "pthread_create"). Taken as
//      const char* so string literals can be passed without the
//      deprecated/ill-formed literal -> char* conversion.
// s:   the error number returned by that call; pthread functions report
//      errors through their return value, not through errno.
void thread_err(const char* err, int s)
{
	// Trailing newline so consecutive error reports do not run together.
	fprintf(stderr, "%s:%s\n", err, strerror(s));
}

class Thread_pool;

// Thin adapter used as the pthread start routine for Thread_pool workers.
// The worker loop (Thread_pool::work_fun) is a non-static member function,
// whose type does not match the void* (*)(void*) routine pthread_create()
// expects, so the pool hands this free function to pthread_create() instead.
void* work_package(void* obj);

class Thread_pool
{
public:
	Thread_pool():shutdown(false){};
	~Thread_pool(){};

	int init(int threadnum)
	{
		int s;
		s = pthread_mutex_init(&(queue_lock), NULL);
		if(s != 0)
		{
			thread_err("pthread_mutex_init", s);
		}

		s = pthread_cond_init(&(queue_ready), NULL);
		if(s != 0)
		{
			thread_err("pthread_cond_init", s);
		}

		for(int i = 0; i < threadnum; ++i)
		{
			pthread_t pid;
			s = pthread_create(&pid, NULL, &work_package, this);
			if(s != 0)
			{
				thread_err("pthread_create", s);
			}
			//pthread_t is similar to pointer, is a flag;
			threads.push_back(pid);
		}
	}

	void destory()
	{
		pthread_mutex_lock(&queue_lock);
		shutdown = true;
		pthread_mutex_unlock(&queue_lock);

		pthread_cond_broadcast(&(queue_ready));

		int s;
		void* res;

		for(int i = 0; i < threads.size(); ++i)
		{
			s = pthread_join(threads[i],&res);
			if(s != 0)
				thread_err("pthread_join", s);

			free(res);
		}

		pthread_mutex_destroy(&(queue_lock));
		pthread_cond_destroy(&(queue_ready));


	}

	void* add_task(void *(*fun)(void* arg), void* arg)
	{
		task_t* task = new task_t;
		task->fun = fun;
		task->arg = arg;

		pthread_mutex_lock(&(queue_lock));

		tasks.push(task);

		pthread_mutex_unlock(&(queue_lock));

		pthread_cond_signal(&(queue_ready));

		return NULL;
	}

	void work_fun(void* arg)
	{
		while(true)
		{
			pthread_mutex_lock(&queue_lock);
			while(tasks.size() == 0 && !shutdown)
			{
				pthread_cond_wait(&(queue_ready), &queue_lock);
			}

			if(shutdown)
			{
				pthread_mutex_unlock(&queue_lock);
				pthread_exit(NULL);
			}

			task_t* task = tasks.front();
			tasks.pop();

			pthread_mutex_unlock(&queue_lock);

			(*(task->fun))(task->arg);

			delete task;
			task = NULL;
		}

		pthread_exit(NULL);
	}


private:
	std::queue tasks;
	std::vector threads;

	pthread_mutex_t queue_lock;
	pthread_cond_t queue_ready;

	bool shutdown;
};

void* work_package(void* obj)
{
	Thread_pool* tmp = static_cast(obj);

	tmp->work_fun(NULL);

	return NULL;
}


// Demo job used by main(): prints "<value>:function" for the int that
// `arg` points to.
// arg: must point to a live int (main() passes elements of a local array).
// Returns NULL; the pool ignores a job's return value.
void* my_test(void* arg)
{
	int* i = static_cast<int*>(arg);

	printf("%d:function\n", *i);
	return NULL;
}

// Demo: start a 4-thread pool, queue one my_test job per element of
// `values`, give the workers two seconds to run them, then shut down.
// The sleep matters: on shutdown the workers exit without running jobs
// that are still queued, so anything pending at destory() time never runs.
int main()
{
	Thread_pool pool;
	pool.init(4);

	int values[] = {0,1,2,3,4,5,6,7,8,9};
	const int count = sizeof(values) / sizeof(values[0]);

	int idx = 0;
	while(idx < count)
	{
		pool.add_task(&my_test, values + idx);
		++idx;
	}

	sleep(2);
	pool.destory();
	return 0;
}

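// 你可能感兴趣的:(使用linux系统函数和c++03类写的一个简单线程池)
// (Leftover footer from the original web article: "You may also be interested in: A simple thread pool built with Linux pthread functions and C++03 classes".)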