#include
#include
using namespace std;
class Resumable
{
};
Resumable func() {
cout << "hello";
co_await std::suspend_always();
cout << " world";
}
int main()
{
}
error: unable to find the promise type for this coroutine
13 | co_await std::suspend_always();
| ^~~~~~~~
其实编译器在编译时,会希望生成如下的代码:
/* 经过编译器优化后的 func 函数 */
Resumable func()
{
Frame *frame = operator new(size); // size = 函数形参大小 + 局部变量大小
Rumable::promise_type promise;
coroutine_handle *handle = coroutine_handle<>::from_promise(&promise);
Resumable res = promise.get_return_object(); // call the Resumable constructor
co_await promise.initial_suspend(); // in some ways, this is a coroutine constructor
try {
// func-body
cout << "hello";
co_await std::suspend_always();
cout << " world";
// func-body end
}catch (...) {
promise.unhandled_exception(); // coroutine exception handle
}
co_await promise.final_suspend(); // in some ways, this is a coroutine destructor
return res;
}
通过上面的代码,可以引出两个问题:
同样,从上面的代码中可以推出,promise_type至少应该含有以下代码:
class promise_type
{
public:
auto get_return_object();
auto initial_suspend();
void unhandled_exception();
auto final_suspend();
void return_void();
};
抱着上面三个问题,看看Resumable的实现规范。
class Resumable
{
public: /* 用户自定义实现部分 */
class promise_type; // 见上个代码块
/*
用户的其他自定义实现代码
*/
};
解决上面提出的问题:
先继续存疑
答:从 Resumable 中由用户手动定义而来,且必须实现一些特定方法。
答:Resumable 必须包含 promise_type 子类型(typedef也算),其余没什么讲究。
再提出一些新问题:
co_await
到函数外?以下是未解决问题列表:
从上面可以看出,co_await 之类的协程关键字依然存在,这说明此处的编译优化并不是针对协程的,那为什么要这样做呢?
答案是为了更好的管理协程,可以看到,一次小小的协程函数调用覆盖了诞生、运行、错误处理、消亡等各个部分,这为将来高可用的框架奠定了基础,但对于写hello world的人不得不说,真***复杂。
C++有一个设计规范,叫做一个人只做一件事,在这里promise_type用来管理协程的生命周期。Resumable用来作为返回值。
如果说上面讲的都是协程规范的话,那么接下来要讲的部分就是具体协程的实现,看看 co_await 到底是如何调用的?
回到最初的起点:
Resumable func() {
cout << "hello";
co_await std::suspend_always();
cout << " world";
}
这里co_await std::suspend_always();
调用的是标准库coroutine
中的函数,直接来看看他的实现:
// 17.12.5 Trivial awaitables
/// [coroutine.trivial.awaitables]
struct suspend_always
{
constexpr bool await_ready() const noexcept { return false; }
constexpr void await_suspend(coroutine_handle<>) const noexcept {}
constexpr void await_resume() const noexcept {}
};
这就是co_await expr
的通用标准实现,写的再通俗点就是:
class Awaiter
{
public:
bool await_ready();
auto await_suspend(coroutine_handle<> handle);
auto await_resume();
};
await_ready
:该任务是否已经完成?若未完成则将调用await_suspend
await_suspend
:是否中止该任务?若中止则该协程将调用权返回给 caller 。void
:程序内部调用handle.resume()
以继续运行协程。bool
:返回true表示同意中止,否则继续执行std::coroutine_handle<>
:调用该handle的resume,随后调用权返回 callerawait_resume
:用于返回协程值,可以是任意类型编译器会通过以下两种路径对 await 语句进行优化:
(下面我用伪代码表示了await_suspend
在不同返回值下的编译代码)
在调用co_await 等协程关键字的位置,程序的操作大概类似这样:
if(!a.await_ready()) {
# if await_suspend_return_void
try {
result = a.await_suspend(handle);
# if !(await_suspend_return_bool && await_suspend_return_coroutine_handle)
return_to_caller();
#endif
} catch(...) {
excpetion = std::current_exception();
goto resume_point;
}
#elif await_suspend_return_bool
if(!result)
goto resume_point;
return_to_caller();
#elif await_suspend_return_coroutine_handle
result.resume();
return_to_caller();
# endif
resume_point:
if(exception)
std::rethrow_exception(exception);
return a.await_resume();
}
到这里,编译器完成了对co_await关键字的优化,接下来看个实例。
众所周知,python里有这样一个函数可以这样用:
for i in range(0, 10):
print(i, end=' ')
Out:
0 1 2 3 4 5 6 7 8 9
def my_range(low, high):
print('are you ok?')
while low < high:
yield low
low += 1
iter = my_range(0, 10)
while True:
try:
print(iter.send(None))
except StopIteration:
break
Out:
0 1 2 3 4 5 6 7 8 9
如果用CPP20呢?首先写出来自己想要的执行代码,然后再考虑如何实现,说得高大上点,用测试驱动开发。
我们最终想要的是这样的效果:
int main()
{
Resumable iter = range(low, high);
while(true) {
try{
cout << iter.get() << " ";
iter.resume();
}catch (...) {
break;
}
}
}
具体实现:
#include
#include
#include
using namespace std;
class Awaiter
{
public:
Awaiter(int val):val(val) { }
bool await_ready() { return false;}
void await_suspend(coroutine_handle<> handle) { }
void await_resume() { }
int val;
};
class Resumable
{
public:
class promise_type
{
public:
auto get_return_object() { return Resumable(Handle::from_promise(*this)); }
auto initial_suspend() noexcept { return std::suspend_never(); }
auto final_suspend() noexcept { return std::suspend_never(); }
void unhandled_exception() { throw; }
void return_void() { }
Awaiter await_transform(Awaiter awaiter) {
cur_val = awaiter.val;
return awaiter;
}
int cur_val = 0;
};
typedef coroutine_handle<promise_type> Handle;
Resumable(Handle handle):handle(handle) {}
void resume() { handle.resume(); }
int get() { return handle.promise().cur_val; }
private:
Handle handle;
};
Resumable range(int low, int high)
{
while(low < high) {
co_await Awaiter(low++);
}
}
int main()
{
int low = 0, high = 10;
Resumable iter = range(low, high);
while(true) {
try{
cout.flush() << iter.get() << " ";
iter.resume();
}catch (...) {
break;
}
}
}
Out:
0 1 2 3 4 5 6 7 8 9
当然,我上面为了学习理解所以强行将awaiter
作了一个产出器,实际上这个工作应该交由co_yield
来完成,他会调用promise_type.yield_value(expr)
,可以直接从promise中拿到数值,更为简便。
具体的例子可参考cpp reference 里的这篇文章。
考虑下面的代码:
async_void func()
{
cout << "hello ";
co_await std::suspend_always();
cout << "world" << endl;
}
int main()
{
auto f = func();
f.resume();
}
经过协程优化以后:
async_void coro_func()
{
Frame *frame = operator new(size); // size = 函数形参大小 + 局部变量大小
async_void::promise_type promise;
async_void ret = promise.get_return_objet();
int status = 0;
void resume() {
switch(status) {
case 0:
return f0();
case 1:
return f1();
}
}
void f0() { status = 1; cout << "hello "; }
void f1() { cout << "world" << endl; }
return ret;
}
写得有点累了,先发这么多,如果有人看就继续往下写。
class async_void
{
public:
class promise_type {
public:
auto get_return_object() { return async_void{Handle::from_promise(*this)}; }
auto initial_suspend() { return std::suspend_never(); }
void unhandled_exception() { throw; }
auto final_suspend() { return std::suspend_never(); }
void return_void() {}
};
typedef std::coroutine_handle<promise_type> Handle;
explicit async_void(Handle h):handle(h) {}
Handle handle;
bool resume() {
if(!handle.done())
handle.resume();
return !handle.done();
}
}
async_void func()
{
cout << "hello ";
co_await std::suspend_always();
cout << "world" << endl;
}
int main()
{
auto coro = func();
while(coro.resume());
}
async_void func()
{
Frame *frame = operator new(size); // size = 函数形参大小 + 局部变量大小
async_void::promise_type promise;
coroutine_handle *handle = coroutine_handle<>::from_promise(&promise);
async_void res = promise.get_return_object(); // call the Resumable constructor
co_await promise.initial_suspend(); // in some ways, this is a coroutine constructor
try {
cout << "hello ";
co_await std::suspend_always();
cout << "world" << endl;
}catch (...) {
promise.unhandled_exception(); // coroutine exception handle
}
co_await promise.final_suspend(); // in some ways, this is a coroutine destructor
return res;
}
auto a = std::suspend_always();
if(!a.await_ready()) {
try {
a.await_suspend(handle);
return_to_caller();
}catch(...) {
exception = std::current_exception();
goto resume_point();
}
}
resume_point:
if(exception)
std::rethrow_exception(exception);
return a.await_resume();
async_void hi() {
cout << "h";
co_await std::suspend_always();
cout << "i";
}
async_void func() {
cout << "hello";
co_await hi().await;
cout << "world";
}
int main() {
auto coro1 = hi();
auto coro2 = func();
coro1.resume();
coro2.resume();
}