c++ 线程池

项目需要多线程化，要写一个生产者—消费者模型。想起很久以前看过线程池的实现，所以就自己写一个练练手。


参考 http://blog.csdn.net/kankan231/article/details/24499947


/***********************************任务接口,抽象类 task.h***************************************/
#ifndef TASK_H
#define TASK_H

/*
 * Abstract unit of work for thread_pool.
 *
 * Derive from task and implement run(). A pool worker thread calls run()
 * once per submission made through thread_pool::add_task(). The pool ignores
 * the return value and never deletes the task, so ownership stays with the
 * caller.
 */
class task{
public:
    // Virtual so that deleting a derived task through a task* runs the
    // derived destructor (without it, such a delete is undefined behaviour).
    virtual ~task() {}

    // Perform the work. Called on a pool worker thread, outside the pool's lock.
    virtual void * run() = 0;
};

#endif

/**********************************thread_pool.h 线程池接口**************************************/

#ifndef THREAD_POOL_H
#define THREAD_POOL_H

#include "task.h"
#include <pthread.h>
#include <queue>

/*
 * Fixed-size pool of POSIX worker threads consuming a bounded FIFO of task*.
 *
 * Producers call add_task(). Worker threads (call_back) pop tasks and call
 * task::run() outside the lock. destroy(), which the destructor also calls,
 * stops the pool and joins the workers. The pool never deletes tasks: the
 * caller owns them and must keep them alive until they have run.
 *
 * Not copyable: it owns threads, a mutex and condition variables, and a copy
 * would release them a second time.
 */
class thread_pool{
private:
    pthread_mutex_t mutex;        // locked around every access to tasks and end in the workers
    pthread_cond_t empty;         // workers wait on this while the queue is empty
    pthread_cond_t full;          // producers wait on this while the queue is at capacity
    std::queue< task* > tasks;    // pending tasks in FIFO order; not owned
    int * thread_nums;            // allocated in create(), freed in destroy(); otherwise unused
    int thread_num;               // number of worker threads started by create()
    int max_task_num;             // queue capacity at which add_task() blocks
    bool end;                     // set by destroy() to request shutdown
    pthread_t *thread_ptr;        // handles of the worker threads (thread_num entries)

    // Non-copyable: declared but intentionally never defined.
    thread_pool(const thread_pool&);
    thread_pool& operator=(const thread_pool&);
public:
    explicit thread_pool(int thread_num=10, int max_task_num=100);
    void create();
    void add_task(task *);
    void destroy();
    friend void * call_back(void *);
    ~thread_pool();
};

#endif

/*****************************************thread_pool.cpp 线程池实现*******************************************/

#include "thread_pool.h"

#include <pthread.h>

#include <cstdlib>
#include <iostream>

#include "task.h"

using namespace std;
// call_back must be declared at namespace scope before create() hands it to
// pthread_create: the friend declaration inside thread_pool alone does not
// make the name visible to ordinary lookup.
void * call_back(void*);

/*
 * Build the pool: initialise the synchronisation objects, then start the
 * worker threads through create().
 *
 * threads  - number of worker threads to start
 * capacity - queue size at which add_task() starts blocking
 */
thread_pool::thread_pool(int threads, int capacity)
    : thread_num(threads)
    , max_task_num(capacity)
    , end(false)
    , thread_ptr(NULL)
{
    pthread_cond_init(&full, NULL);
    pthread_cond_init(&empty, NULL);
    pthread_mutex_init(&mutex, NULL);
    this->create();
}

void thread_pool::create(){
    thread_ptr = new pthread_t[thread_num];
    thread_nums = new int[thread_num];
    for(int i=0; i<thread_num; i++){
        int ret = 0;
        if( ret = pthread_create(&thread_ptr[i], NULL, call_back, (void*)this) ){
            cout << "pthread_create error:error code " << ret << endl;
            exit(1);
        }
    }
    return ;
}

/*
 * Queue a task for execution by a worker thread.
 *
 * Blocks while the queue already holds max_task_num tasks. A max_task_num of
 * 0 or less means "no limit". Returns without queuing when t is NULL or once
 * shutdown has started (end set by destroy()).
 *
 * The pool does not take ownership of t. The caller must keep it alive until
 * a worker has run it, for example until destroy() has returned.
 */
void thread_pool::add_task(task * t){
    // Unlocked early exit. After destroy() the mutex has been destroyed, so a
    // late call must return before touching it.
    if (end || t == NULL){
        return ;
    }
    pthread_mutex_lock(&mutex);
    // Also stop waiting once shutdown starts, so destroy()'s broadcast on
    // `full` releases blocked producers instead of putting them back to sleep.
    while (!end && max_task_num > 0 && tasks.size() >= static_cast<size_t>(max_task_num)){
        pthread_cond_wait(&full, &mutex);
    }
    if (end){
        pthread_mutex_unlock(&mutex);
        return ;
    }
    tasks.push(t);
    // One new task: waking one idle worker is enough.
    pthread_cond_signal(&empty);
    pthread_mutex_unlock(&mutex);
}

/*
 * Shut the pool down. Sets `end`, wakes every thread waiting on `empty` or
 * `full`, and joins all worker threads. Then it frees the thread arrays and
 * destroys the mutex and condition variables.
 *
 * Queued tasks are not cleared here. Whether they still run is decided by the
 * worker loop (call_back) before each worker exits.
 *
 * Calling destroy() again, as the destructor does, is a no-op because `end`
 * is already true.
 *
 * NOTE(review): do not call this from a worker thread, since it would try to
 * join itself. No other thread may still be inside add_task() when the mutex
 * and condition variables are destroyed below.
 */
void thread_pool::destroy(){
    if (end){
        return ;
    }
    pthread_mutex_lock(&mutex);
    end = true;
    pthread_mutex_unlock(&mutex);

    pthread_cond_broadcast(&empty);
    pthread_cond_broadcast(&full);
    for(int i=0; i<thread_num; i++){
        pthread_join(thread_ptr[i], NULL);
    }
    // Null the freed arrays so nothing can later use or free them again.
    delete[] thread_ptr;
    thread_ptr = NULL;
    delete[] thread_nums;
    thread_nums = NULL;
    pthread_mutex_destroy(&mutex);
    pthread_cond_destroy(&empty);
    pthread_cond_destroy(&full);
}

void *call_back(void * argv){
    thread_pool* pool = (thread_pool*)argv;
    cout << "start thread " << pthread_self() << endl;
    while(1){
        pthread_mutex_lock(&(pool->mutex));
        if (pool->end && pool->tasks.empty()){                    //保证将tasks中任务做完才退出
            pthread_mutex_unlock(&(pool->mutex));
            pthread_exit(NULL) ;
        } //没有这段代码会使退出时广播信息过后又运行到后面wait导致线程一直wait
        task * t;
        while (pool->tasks.empty()){
            pthread_cond_wait(&(pool->empty), &(pool->mutex));
            if (pool->end){
                pthread_mutex_unlock(&(pool->mutex));
                pthread_exit(NULL);
            } //没有这段退出代码会导致线程一直wait
        }
        if (pool->tasks.size() == pool->max_task_num){
            t = pool->tasks.front();
            pool->tasks.pop();
            pthread_cond_broadcast(&(pool->full));
        }else{
            t = pool->tasks.front();
            pool->tasks.pop();
        }
        pthread_mutex_unlock(&(pool->mutex));
        t->run();
    }
    cout << "end thread " << pthread_self() << endl;
    pthread_exit(NULL);
}

/*
 * Destructor: shuts the pool down through destroy(). This is harmless when
 * destroy() was already called explicitly, because destroy() returns early
 * once `end` is set.
 */
thread_pool::~thread_pool()
{
    this->destroy();
}

/*****************************************************test.cpp***************************************/
#include "task.h"
#include "thread_pool.h"
#include <unistd.h>

using namespace std;

/* Demo task: prints "task <num>" when a worker runs it. */
class mytask: public task{
public:
    mytask(int num=10):num(num){}

    void * run(){
        cout << "task " << num << endl;
        return NULL;
    }

private:
    int num;   // value printed by run()
};

/* Number of demo tasks submitted by main(). */
static const int TASK_NUM = 20;

/*
 * Demo: submit TASK_NUM tasks to a pool with 3 workers and a queue capacity
 * of 10, then shut the pool down.
 *
 * The pool never deletes tasks. The tasks therefore live in an automatic
 * array declared BEFORE the pool, so they are destroyed only after the pool,
 * whose destructor joins the workers, has gone away.
 */
int main(){
    mytask t_array[TASK_NUM];
    thread_pool pool(3, 10);
    for(int i=0; i<TASK_NUM; i++){
        t_array[i] = mytask(i+10);
        pool.add_task(&t_array[i]);
    }
    pool.destroy();   // joins the workers before the tasks go out of scope
    return 0;
}

/*********************************************************Makefile.am*****************************************/

# NOTE(review): bin_PROGRAMS installs this binary into $(bindir) under the
# name "test", shadowing the system test(1). Consider noinst_PROGRAMS or a
# different program name.
bin_PROGRAMS = test
test_SOURCES = task.h thread_pool.h thread_pool.cpp test.cpp
# Per-program link libraries go in <prog>_LDADD. LIBS is filled in by
# configure and should not be appended to from Makefile.am.
test_LDADD = -lpthread

你可能感兴趣的:(c++ 线程池)