[C++] 一个通用协程类模板

[C++] 一个通用协程类模板

文章目录

  • [C++] 一个通用协程类模板
    • 源码
    • 使用
      • 基本概念
        • 协程函数
        • 协程
      • 定义协程函数
      • 生成器型协程
      • 普通协程
      • STL的协程类
      • 常用函数
      • 一些调用检查
    • 异常

源码

#ifndef __MYCORO_H__
#define __MYCORO_H__

#include 
#include 

#define DELETE_COPY_FUNCTION(cls) \
	cls (const cls&) = delete; \
	cls& operator= (const cls&) = delete;

#define DELETE_MOVE_FUNCTION(cls) \
	cls (cls&&) = delete; \
	cls& operator= (cls&&) = delete;

#define DELETE_COPY_MOVE_FUNCTION(cls) \
	DELETE_COPY_FUNCTION(cls) \
	DELETE_MOVE_FUNCTION(cls)

namespace my_coro {

	struct Followable {
		virtual bool follow() const noexcept = 0;
		virtual Followable const* next_pfollowable() const noexcept = 0;
		void go() const noexcept {
			Followable const* f = this;
			while(f != nullptr && f->follow()) {
				f = f->next_pfollowable();
			}
		}
	};

	template<typename RsmT>
	struct Resumable : public Followable {
		virtual bool send(RsmT const&) noexcept = 0;
	};
	template<>
	struct Resumable<void> : public Followable {
		virtual bool send() noexcept = 0;
	};

	template<typename RtnT>
	struct CoroPromiseWithReturn {
		std::optional<RtnT> rtn_value;
		void return_value(RtnT&& v) noexcept { rtn_value.emplace(std::move(v)); }
	};
	template<>
	struct CoroPromiseWithReturn<void> {
		void return_void() const noexcept {}
	};

	template<typename RsmT>
	struct CoroPromiseWithResume {
		struct ResumeAwaiter {
			const RsmT* const& d;
			bool await_ready() const noexcept { return false; }
			void await_suspend(std::experimental::coroutine_handle<>) const noexcept {}
			const RsmT& await_resume() const noexcept { return *d; }
		};
		const RsmT* prsm_value;
		auto get_awaiter() noexcept {
			return ResumeAwaiter{ prsm_value };
		}
	};
	template<>
	struct CoroPromiseWithResume<void> {
		auto get_awaiter() noexcept { return std::experimental::suspend_always(); }
	};

	template<typename SpdT>
	struct CoroPromiseWithSuspend {
		const SpdT* pspd_value;
		Resumable<SpdT>* presumable_coroutine;
	};
	template<>
	struct CoroPromiseWithSuspend<void> {};

	template<typename SpdT, typename RsmT, typename RtnT>
	class CoroPromiseCore :
		public CoroPromiseWithSuspend<SpdT>,
		public CoroPromiseWithResume<RsmT>,
		public CoroPromiseWithReturn<RtnT> {
		DELETE_COPY_MOVE_FUNCTION(CoroPromiseCore);
	protected:
		std::optional<std::exception_ptr> pexcept;
	public:
		CoroPromiseCore() noexcept :
			CoroPromiseWithSuspend<SpdT>(),
			CoroPromiseWithResume<RsmT>(),
			CoroPromiseWithReturn<RtnT>(),
			pexcept(std::nullopt) {
		}
	public:
		auto yield_value(const SpdT& v) noexcept {
			this->pspd_value = std::addressof(v);
			return this->get_awaiter();
		}
		auto initial_suspend() const noexcept { return std::experimental::suspend_never(); }
		auto final_suspend() noexcept { return std::experimental::suspend_always(); }
		void unhandled_exception() noexcept {
			try {
				std::rethrow_exception(std::current_exception());
			} catch(std::exception& e) {
				pexcept.emplace(std::make_exception_ptr(new std::exception(e)));
				std::cerr << "exception in stack of the coroutine was copied into heap now: " << e.what() << std::endl;
			} catch(...) {
				pexcept.emplace(std::current_exception());
			}
		}
	public:
		bool follow() const noexcept {
			Resumable<SpdT>* pobj = this->presumable_coroutine;
			if(pobj != nullptr) {
				const SpdT* pspd = this->pspd_value;
				if(pspd != nullptr) {
					return pobj->send(*pspd);
				}
			}
			return false;
		}
		void rethrow_if_failed() const {
			if(*pexcept) {
				std::rethrow_exception(*pexcept);
			}
		}
	};

	template<typename T>
	struct CoroIterator {
		T& co;
		CoroIterator(T & co) noexcept : co(co) {}
		const CoroIterator& operator++ () const noexcept { co.send(); return *this; }
		bool operator!= (CoroIterator const& end) const noexcept { return co; }
		const typename T::suspend_type& operator* () const noexcept { return co.recv(); }
	};

#define CORO_COMMON_FUNCTION(Coro, SpdT, RsmT, RtnT) \
		DELETE_COPY_FUNCTION(Coro);																								\
	Coro& operator= (Coro &&) = delete;																					\
	public:																																			\
		using suspend_type = SpdT;																								\
		using resume_type = RsmT;																									\
		using return_type = RtnT;																									\
		struct promise_type : public CoroPromiseCore {					\
			using CoroPromiseCore::CoroPromiseCore;								\
			Coro get_return_object() noexcept {																			\
				return Coro(*this);																										\
			}																																				\
		};																																				\
		using handle_type = std::experimental::coroutine_handle;		\
		static_assert(std::is_void_v == 0,													\
									"suspend_type can not be void");														\
	protected:																																	\
		promise_type& promise;																										\
		handle_type handle;																												\
	public:																																			\
		explicit Coro(promise_type& promise) noexcept :														\
			promise(promise), handle(handle_type::from_promise(promise)) {}					\
		Coro(Coro &&self) noexcept : promise(self.promise) {}											\
		virtual ~Coro() noexcept { handle.destroy(); }														\
		void rethrow_if_failed() const { return promise.rethrow_if_failed(); }		\
		const SpdT& recv() const noexcept { return *promise.pspd_value; }					\
		bool finalized() const noexcept { return handle.done(); }									\
		template && std::is_same_v>::type>										\
		const T& get_return() const noexcept { return *promise.rtn_value; }				\
		void link(Resumable & rsmobj) noexcept {														\
			promise.presumable_coroutine = std::addressof(rsmobj);									\
		}																																					\
		void unlink() noexcept { promise.presumable_coroutine = nullptr; }				\
																																							\
	public:																																			\
		operator bool() const noexcept { return !finalized(); }										\
		const SpdT& operator*() const noexcept { return recv(); }									\
		template																	\
		Coro<_SpdT, SpdT, _RtnT>& operator | (																		\
			Coro<_SpdT, SpdT, _RtnT>& robj) noexcept {															\
			link(robj);																															\
			return robj;																														\
		}																																					\
																																							\
	public:																																			\
		bool follow() const noexcept override { return promise.follow(); }				\
		Followable const* next_pfollowable() const noexcept override {						\
			Resumable const* prsmobj = promise.presumable_coroutine;					\
			return prsmobj;																													\
		}																																					\
	private:

	template<typename SpdT, typename RsmT = void, typename RtnT = void>
	class Coro final :
		public Resumable<RsmT> {
		CORO_COMMON_FUNCTION(Coro, SpdT, RsmT, RtnT)
	public:
		bool send(RsmT const& v) noexcept override {
			promise.prsm_value = std::addressof(v);
			handle.resume();
			return !finalized();
		}
	};

	template<typename SpdT, typename RtnT>
	class Coro<SpdT, void, RtnT> final :
		public Resumable<void> {
		CORO_COMMON_FUNCTION(Coro, SpdT, void, RtnT)
	public:
		bool send() noexcept override {
			handle.resume();
			return !finalized();
		}
		CoroIterator<Coro> begin() { return CoroIterator<Coro>(*this); }
		CoroIterator<Coro> end() { return CoroIterator<Coro>(*this); }
	};

#undef CORO_COMMON_FUNCTION

}

#undef DELETE_COPY_MOVE_FUNCTION
#undef DELETE_MOVE_FUNCTION
#undef DELETE_COPY_FUNCTION

#endif

使用

基本概念

协程函数

含有co_yieldco_awaitco_return的函数为协程函数. 协程函数不能有return.

协程函数调用后返回协程对象. 即协程由协程函数而生. 可通过协程对象控制协程运行.

协程函数的返回类型需为符合要求的协程类. 本文提供的协程类模板可用于生成符合要求的协程类.

协程

协程可向外界发送值, 外界亦可向协程发送值. 前者称挂起值, 后者称恢复值.
协程结束后亦可有返回值. 该三类值的类型在本文提供的类模板里均可自定义.
其中, 协程的挂起值类型不能为void.

定义协程函数

使用下列语法产生一个协程类. =void表示此处可留空, 默认为void.

my_coro::Coro<挂起值类型, 恢复值类型=void, 返回值类型=void>

之后可定义协程函数, 令其返回类型为协程类. 协程使用co_yield向外界发送挂起值和接收恢复值.
下面的协程函数能生成挂起值类型为int, 恢复值类型为void, 返回值类型为void的协程. 这也是生成器型协程.

Coro<int> coro_f() {
	for(int i = 0; i < 10; ++i) {
		co_yield i;
	}
}

生成器型协程

生成器型协程即恢复值类型为void的协程. 有时简称生成器.

生成器型协程不需要恢复值, 因此为该类协程提供了迭代器. 使用如下代码依次遍历协程的所有挂起值.

Coro<挂起值类型> co = coro_f();
for(const 挂起值类型& i : co) {
	对i的操作...
}

若不使用迭代器, 亦可像普通协程那样遍历协程.

Coro<int> co = coro_f();
while(co) {
	std::cout << *co << std::endl;
	co.send();
}

STL亦提供有一个生成器, 功能基本相同. 亦可用范围for循环迭代.

std::experimental::generator<挂起值类型>

普通协程

普通协程即有恢复值的协程. 通过向协程发送恢复值来驱动协程运行.
下面是无返回值的例子.

Coro<int, int> coro_f() {
	int i = -1;
	while(i != 0) {
		i = 10 + (co_yield i);
	}
}

下面是有返回值的例子.

Coro<int, int, int> sum_coro() {
	int s = 0;
	int i;
	do {
		i = co_yield 0;
		s += i;
	} while(i != 0);
	co_return s;
}

外界在协程对象上调用send函数向协程发送恢复值. 使用get_return函数获取返回值.

Coro<int, int, int> sco = sum_coro();
int i = 0;
while(i <= 100) {
	sco.send(i);
}
std::cout << sco.get_return() << std::endl;

STL的协程类

std::futurestd::experiment::generator是两个符合要求的协程类. 前者要求协程不能有co_yield, 后者要求协程不能有co_return 非void值;.

当需要快速编写协程函数时可使用std::future, 以方便快速使用co_await.

常用函数

本文的协程对象实现了operator bool运算符, 可用于直接测试协程是否已返回.

if(co) {
	...
}

同时还实现了operator*运算符, 可用于获取挂起值. 亦可使用recv()函数.

std::cout << *co << std::endl;
std::cout << co.recv() << std::endl;

可让一个协程的输出直接接到另一个的协程的输入. 只要两个的协程的挂起值类型和恢复值类型相同. 连接协程可使用管道运算符(按位或运算符)operator|, 亦可使用link()函数.

co1 | co2

co1的挂起值将直接输入给co2. 若要启动协程的多米诺式运行, 可在第一个协程上调用go(). 所有协程都实现了Followable接口, 可使用该接口自行控制协程的链式运行.

co1.go();

一些调用检查

send()函数在协程的恢复值类型为void时无参数, 非void时有参数.
get_return()函数在协程的返回值类型定义为void时不可调用.

异常

协程内抛出异常后, 将直接退出. 协程对象将记录该异常. 外界可用rethrow_if_exception()重抛出该异常.
栈上C++异常将被复制, 以让外界访问. 即std::exception&型异常将转为std::exception*, 注意记得调用delete释放内存. 建议不要抛出栈上异常, 除非自行修改源码令其支持复制其他栈上的自定义异常.

你可能感兴趣的:([C++] 一个通用协程类模板)