c++11 实现读写锁

 


Note: C++ 17里已经引进了读写锁 std::shared_mutex , 其lock()即以写方式加锁, 其lock_shared()即以读方式加锁。 
https://en.cppreference.com/w/cpp/thread/shared_mutex


实现一个读写锁类, 可以有两种方式获取锁,读方式,写方式。 允许多个"读线程"同时进入临界区,但是同一时刻只允许一个"写线程"进入临界区。
当有写线程进入临界区时,不允许任何其他读或写线程同时进入。  写线程优先。

方法:

在类中增加成员, 记录当前正在临界区的“读线程”,"写线程"数目, 等待进入临界区的“读线程”,"写线程"数目。
增加2个信号量成员 用于"写线程"的等待和唤醒, 用于“读线程”的等待和唤醒。

常用场景
多个读线程频繁访问临界区,偶尔有一个写线程进入临界区。

 

C++11 实现:

#pragma once
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;

class cppReadWriteLock
{
public:
    cppReadWriteLock():
        mWaitReadThreadNum(0),
        mWaitWriteTrheadNum(0),
        mReadingThreadNum(0),
        mWritingThreadNum(0){
        }

    ~cppReadWriteLock() {};

    void getReadLock() {
        unique_lock uniLock(mMyMutex);
        if (mWritingThreadNum || mWaitWriteTrheadNum) { //写优先,只要有线程在等待写,则不能让读线程得到机会。
            ++mWaitReadThreadNum;
            while (mWritingThreadNum || mWaitWriteTrheadNum) {
                mReadThreadCV.wait(uniLock);
            }
            --mWaitReadThreadNum;
        }
        ++mReadingThreadNum;
    }

    void getWriteLock() {
        unique_lock uniLock(mMyMutex);
        if (mWritingThreadNum || mReadingThreadNum) {
            ++mWaitWriteTrheadNum;
            while (mWritingThreadNum || mReadingThreadNum) {
                mWriteThreadCV.wait(uniLock);
            }
            --mWaitWriteTrheadNum;
        }
        ++mWritingThreadNum;
    }

    void releaseReadLock() {
        unique_lock uniLock(mMyMutex);
        --mReadingThreadNum;
        if (mWaitWriteTrheadNum) {//有写线程在等待的话,直接尝试唤醒一个写线程,即使还有其他线程在读。写优先!

            mWriteThreadCV.notify_one();
         
        }
    }

    void releaseWriteLock() {
        unique_lock uniLock(mMyMutex);
        --mWritingThreadNum;
        if (mWaitWriteTrheadNum) {//写优先
            mWriteThreadCV.notify_one();
        }
        else {
            mReadThreadCV.notify_all();//通知所有被阻塞的read线程
        }
    }

private:
    int mWaitReadThreadNum, mReadingThreadNum;

    int mWaitWriteTrheadNum, mWritingThreadNum;

    mutex mMyMutex;
    condition_variable mReadThreadCV;//用于“读线程”的等待和唤醒。
    condition_variable mWriteThreadCV;//用于"写线程"的等待和唤醒
};

测试代码:

生产者线程5个, 产生0-49数字, 将产生的数字存到全局变量list尾部。第一个线程产生0~9.
消费者线程10个,用于从全局变量list头部get数据,并打印,被get到的数据从list剔除;
观察者线程2个, 用于打印当前list中的元素。

#include "pch.h"
#include "cppReadWriteLock.h"


//生产者线程5个, 产生0-49数字, 将产生的数字存到全局变量list尾部。第一个线程产生0~9.
//消费者线程10个,用于从全局变量list头部get数据,并打印,被get到的数据从list剔除;
//观察者线程2个, 用于打印当前list中的元素。
const int produceThreadNum = 5;
const int consumeThreadNum = 10;
const int watchThreadNum = 2;

list listCache;
int totalTargetNum = 50;//所有的生产者的目标是总共生产50个数字。
int currentProducedNum = 0;
int currentConsumedNum = 0;

cppReadWriteLock gWrLock;

void produceThread(int stIdx, int num) {

    for (int i = stIdx; i < stIdx+num; i++) {
        gWrLock.getWriteLock();
        listCache.push_back(i);
        currentProducedNum++;

        cout << "Produce " << i << endl;
        gWrLock.releaseWriteLock();

        //sleep:
        std::this_thread::sleep_for(std::chrono::milliseconds(rand() % 15 + 1));
    }

}
void consumeThread() {
    bool bStop = false;

    while (true) {
        gWrLock.getWriteLock();
        if (!listCache.empty()) {
            int topNumber = listCache.front();
            listCache.pop_front();
            currentConsumedNum++;
            cout << "Consume " << topNumber << endl;
        }
        if (currentConsumedNum >= totalTargetNum) {
            bStop = true;
        }
        gWrLock.releaseWriteLock();

        if (bStop)
        {
            break;
        }
        //sleep:
        std::this_thread::sleep_for(std::chrono::milliseconds(rand() % 15 + 1));
    }
}

void watchThread() {
    bool bIshouldStop = false;
    while (true) {
        gWrLock.getReadLock();
        if (!listCache.empty()) {
            cout << "Watch:  ";
            for (const auto& it : listCache) {
                cout << it << "--";
            }
            cout << endl;
        }
        if (currentConsumedNum >= totalTargetNum) {
            bIshouldStop = true;
        }
        gWrLock.releaseReadLock();

        if (bIshouldStop) {
            break;
        }
        //sleep:
        std::this_thread::sleep_for(std::chrono::milliseconds(rand() % 15 + 1));
    }
}

void main()
{
    vector watchThreadsVec;
    vector consumeThreadsVec;
    vector produceThreadsVec;

    for (int i = 0; i < watchThreadNum; ++i) {
        watchThreadsVec.push_back(thread(watchThread));
    }

    for (int i = 0; i < consumeThreadNum; ++i) {
        consumeThreadsVec.push_back(thread(consumeThread));
    }

    for (int i = 0; i < produceThreadNum; ++i) {
        produceThreadsVec.push_back(thread(produceThread, i*(totalTargetNum / produceThreadNum), totalTargetNum / produceThreadNum));
    }

    for (auto& it : watchThreadsVec) {
        it.join();
    }
    for (auto& it : produceThreadsVec) {
        it.join();
    }
    for (auto& it : consumeThreadsVec) {
        it.join();
    }

}

Ref:

https://github.com/bo-yang/read_write_lock

 

你可能感兴趣的:(c++,stl,C++)