C++11信号量实现

#pragma once
#include <algorithm>
#include <condition_variable>
#include <mutex>

// Counting semaphore built from a std::mutex and a std::condition_variable,
// extended with an open/closed state.
//
// - post(n) adds n permits and wakes up to n blocked waiters.
// - wait() blocks until a permit is available and consumes it.
// - close() marks the semaphore closed and wakes every waiter; wait() then
//   throws closed_exception (also for later callers) until open() is called.
//
// All public member functions take m_mtx, so they may be called concurrently
// from multiple threads.
class semaphore
{
public:
	// Thrown by wait() when the semaphore is closed.
	struct closed_exception {};
public:
	// Creates an opened semaphore holding `cnt` initial permits.
	explicit semaphore(size_t cnt = 0)
		: m_cnt(cnt)
		, m_opened(true)
	{}
	// Re-opens the semaphore so that wait() blocks/consumes again instead of
	// throwing. The permit count is left untouched.
	void open()
	{
		std::lock_guard<std::mutex> _(m_mtx);
		m_opened = true;
	}
	// Closes the semaphore and wakes every blocked waiter; each of them
	// re-evaluates its predicate, sees the closed flag and throws.
	void close()
	{
		std::lock_guard<std::mutex> _(m_mtx);
		m_opened = false;
		m_evt.notify_all();
	}
	// Blocks until a permit is available, then consumes one.
	// Throws closed_exception if the semaphore is closed on entry or becomes
	// closed while blocked. The closed check runs before the permit check, so
	// a closed semaphore throws even when permits are still available, and no
	// permit is consumed in that case.
	void wait()
	{
		std::unique_lock<std::mutex> lck(m_mtx);
		m_evt.wait(lck, [this]
		{
			if (!m_opened)
			{
				throw closed_exception();
			}
			return m_cnt > 0;
		});
		--m_cnt;
	}
	// Adds `n` permits and wakes at most `n` of the currently blocked waiters.
	// The opened/closed state is not consulted here.
	void post(size_t n = 1)
	{
		std::unique_lock<std::mutex> lck(m_mtx);
		m_cnt += n;
		m_evt.notify(lck, n);
	}
protected:
	// Scope guard that counts a thread as a waiter for its lifetime. The
	// destructor also runs when the wait predicate throws, so the counter
	// stays balanced on the closed_exception path.
	class guard_
	{
	public:
		explicit guard_(size_t& waiters)
			: waiters_(waiters)
		{
			++waiters_;
		}
		~guard_()
		{
			--waiters_;
		}
	private:
		size_t & waiters_;
	};
	// Condition variable wrapper that tracks how many threads are inside
	// wait(), so notify() never issues more wake-ups than there are waiters.
	// m_waiters has no synchronization of its own: every member function that
	// touches it must be called with the owning semaphore's mutex held.
	class event_
	{
	public:
		// Blocks on the condition variable until notified (or spuriously woken).
		void wait(std::unique_lock<std::mutex>& lck)
		{
			guard_ _(m_waiters);
			m_cnd.wait(lck);
		}
		// Blocks until predicate `f` returns true. `f` is evaluated with `lck`
		// held; an exception thrown by `f` propagates to the caller.
		template <typename F>
		void wait(std::unique_lock<std::mutex>& lck, F f)
		{
			guard_ _(m_waiters);
			m_cnd.wait(lck, f);
		}
		// Wakes min(n, current waiters) threads. `lck` is not used in the body;
		// it is part of the signature so callers hand over the lock they hold.
		void notify(std::unique_lock<std::mutex>& lck, size_t n = 1)
		{
			(void)lck;
			auto times = std::min(n, m_waiters);
			for (size_t i = 0; i < times; i++)
			{
				m_cnd.notify_one();
			}
		}
		// Wakes every waiting thread.
		void notify_all()
		{
			m_cnd.notify_all();
		}
	private:
		std::condition_variable m_cnd;
		size_t                  m_waiters{ 0 };
	};
private:
	std::mutex m_mtx;               // guards every member below
	event_     m_evt;               // wake-up channel for blocked wait() calls
	size_t     m_cnt{ 0 };          // number of available permits
	bool       m_opened{ true };    // false between close() and open()
};

测试代码

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>
#include "semaphore.h"

// Demo driver for the semaphore: two consumer threads loop on sem.wait() and
// print a running counter for every permit they obtain. The main thread posts
// one permit per second ten times, then closes the semaphore, which makes the
// consumers leave their loops via closed_exception, and finally joins them.
int main(int argc, char* argv[])
{
	(void)argc;
	(void)argv;
	semaphore sem;
	// Incremented by both consumers outside of any lock, hence atomic.
	std::atomic<size_t> cnt{ 0 };
	std::thread thds[2];
	for (auto& thd : thds)
	{
		thd = std::thread([&]
		{
			try
			{
				for (;;)
				{
					sem.wait();
					std::cout << "thread:" << std::this_thread::get_id() << ", semaphore post: " << cnt++ << std::endl;
				}
			}
			catch (const semaphore::closed_exception&)
			{
				// close() was called: this is the normal shutdown path.
				std::cout << "thread:" << std::this_thread::get_id() << ", semaphore closed" << std::endl;
			}
		});
	}

	for (size_t i = 0; i < 10; i++)
	{
		std::this_thread::sleep_for(std::chrono::seconds(1));
		sem.post();
	}
	sem.close();
	for (auto& thd : thds)
	{
		thd.join();
	}
	// 0 signals success to the calling environment.
	return 0;
}


你可能感兴趣的:(C/C++)