手把手理解C++20协程的编译实现

考虑下面的协程代码

#include 
#include 

using namespace std;

class Resumable
{

};

Resumable func() {
    cout << "hello";
    co_await std::suspend_always();
    cout << " world";
}


int main()
{
    
}

编译报错

error: unable to find the promise type for this coroutine
   13 |     co_await std::suspend_always();
      |     ^~~~~~~~

为什么?

其实编译器在编译时,会希望生成如下的代码:

/* 经过编译器优化后的 func 函数 */
Resumable func()
{
    Frame *frame = operator new(size);	//	size = 函数形参大小 + 局部变量大小
    Rumable::promise_type promise;
    coroutine_handle *handle = coroutine_handle<>::from_promise(&promise);
    Resumable res = promise.get_return_object();	//	call the Resumable constructor 

    co_await promise.initial_suspend();	//	in some ways, this is a coroutine constructor
    try {
        //  func-body
        cout << "hello";
        co_await std::suspend_always();
        cout << " world";
        //	func-body end
    }catch (...) {
        promise.unhandled_exception();	//	coroutine exception handle
    }
    co_await promise.final_suspend();	//	in some ways, this is a coroutine destructor

    return res;
}

通过上面的代码,可以引出两个问题:

  1. 已知协程co_await可以完成上下文切换,那这个函数中co_await具体是怎么调用的?
  2. promise_type 哪里来?
  3. Resubmable如何实现?

同样,从上面的代码中可以推出,promise_type至少应该含有以下代码:

class promise_type
{
public:
    auto get_return_object();
    auto initial_suspend();
    void unhandled_exception();
    auto final_suspend();
    void return_void();
};

抱着上面三个问题,看看Resumable的实现规范。

Resumable的编译实现

class Resumable
{
public: /* 用户自定义实现部分 */
    class promise_type;	// 见上个代码块
    
	/* 
		用户的其他自定义实现代码
	*/

};

解决上面提出的问题:

  1. 已知协程co_await可以完成上下文切换,那这个函数中co_await具体是怎么调用的?

先继续存疑

  1. promise_type 哪里来?

答:从 Resumable 中由用户手动定义而来,且必须实现一些特定方法。

  1. Resubmable如何实现?

答:Resumable 必须包含 promise_type 子类型(typedef也算),其余没什么讲究。

再提出一些新问题:

  1. 协程如何将一个值从函数内co_await到函数外?
  2. 看起来Resumable在编译优化后的func里没有被用到,只在最后返回的时候return了一下,为什么不用promise直接代替Resumable?
    换句话说:为什么要给promise加一层外套作为返回类型?C++为什么要这样设计?

以下是未解决问题列表:

  1. 已知协程co_await可以完成上下文切换,那这个函数中co_await具体是怎么调用的?
  2. C++为什么要采用给promise加一层外套作为返回类型这样的设计方式?

总结一下

从上面可以看出,co_await 之类的协程关键字依然存在,这说明此处的编译优化并不是针对协程的,那为什么要这样做呢?

答案是为了更好的管理协程,可以看到,一次小小的协程函数调用覆盖了诞生、运行、错误处理、消亡等各个部分,这为将来高可用的框架奠定了基础,但对于写hello world的人不得不说,真***复杂。

C++有一个设计规范,叫做一个人只做一件事,在这里promise_type用来管理协程的生命周期。Resumable用来作为返回值。

如果说上面讲的都是协程规范的话,那么接下来要讲的部分就是具体协程的实现,看看 co_await 到底是如何调用的?

Awaitable对象的实现规范

回到最初的起点:

Resumable func() {
    cout << "hello";
    co_await std::suspend_always();
    cout << " world";
}

这里co_await std::suspend_always();调用的是标准库coroutine中的函数,直接来看看他的实现:

// 17.12.5 Trivial awaitables
/// [coroutine.trivial.awaitables]
struct suspend_always
{
  constexpr bool await_ready() const noexcept { return false; }

  constexpr void await_suspend(coroutine_handle<>) const noexcept {}

  constexpr void await_resume() const noexcept {}
};

这就是co_await expr 的通用标准实现,写的再通俗点就是:

class Awaiter
{
public:
    bool await_ready();
    auto await_suspend(coroutine_handle<> handle);
    auto await_resume();
};
  • await_ready:该任务是否已经完成?若未完成则将调用await_suspend
  • await_suspend:是否中止该任务?若中止则该协程将调用权返回给 caller 。
    该函数返回bool、void、coro_handle三种之一,对于不同的返回值编译器提供了不同的实现。
    • void:程序内部调用handle.resume()以继续运行协程。
    • bool:返回true表示同意中止,否则继续执行
    • std::coroutine_handle<>:调用该handle的resume,随后调用权返回 caller
  • await_resume:用于返回协程值,可以是任意类型

编译器会通过以下两种路径对 await 语句进行优化:

Created with Raphaël 2.3.0 co_await expr promise_type 存在 await_transform(expr)? promise.await_transform(expr) Awaitable yes no

(下面我用伪代码表示了await_suspend在不同返回值下的编译代码)
在调用co_await 等协程关键字的位置,程序的操作大概类似这样:

if(!a.await_ready()) {


# if await_suspend_return_void
	try {
		result = a.await_suspend(handle);
# if !(await_suspend_return_bool && await_suspend_return_coroutine_handle)
		return_to_caller();
#endif
	} catch(...) {
		excpetion = std::current_exception();
		goto resume_point;
	}
	
#elif await_suspend_return_bool
if(!result)
	goto resume_point;
return_to_caller();

#elif await_suspend_return_coroutine_handle
result.resume();
return_to_caller();

# endif

resume_point:
if(exception)
	std::rethrow_exception(exception);
return a.await_resume();
}

到这里,编译器完成了对co_await关键字的优化,接下来看个实例。

range

众所周知,python里有这样一个函数可以这样用:

for i in range(0, 10):
    print(i, end=' ')

Out:

0 1 2 3 4 5 6 7 8 9 

这个函数本质上可以用 python 协程这样实现:
def my_range(low, high):
    print('are you ok?')
    while low < high:
        yield low
        low += 1


iter = my_range(0, 10)
while True:
    try:
        print(iter.send(None))
    except StopIteration:
        break

Out:

0 1 2 3 4 5 6 7 8 9 

如果用CPP20呢?首先写出来自己想要的执行代码,然后再考虑如何实现,说得高大上点,用测试驱动开发。

我们最终想要的是这样的效果:

int main()
{
    Resumable iter = range(low, high);
    while(true) {
        try{
          cout << iter.get() << " ";
          iter.resume();
        }catch (...) {
            break;
        }
    }
}

具体实现:

#include 
#include 
#include 

using namespace std;


class Awaiter
{
public:
    Awaiter(int val):val(val) {  }
    bool await_ready() { return false;}
    void await_suspend(coroutine_handle<> handle) {  }
    void await_resume() {  }

    int val;
};


class Resumable
{
public:
    class promise_type
    {
    public:
        auto get_return_object() { return Resumable(Handle::from_promise(*this)); }
        auto initial_suspend() noexcept { return std::suspend_never(); }
        auto final_suspend() noexcept { return std::suspend_never(); }
        void unhandled_exception() { throw; }
        void return_void() { }
        Awaiter await_transform(Awaiter awaiter) {
            cur_val = awaiter.val;
            return awaiter;
        }

        int cur_val = 0;
    };

    typedef coroutine_handle<promise_type> Handle;
    Resumable(Handle handle):handle(handle) {}

    void resume() { handle.resume(); }
    int get() { return handle.promise().cur_val; }

private:
    Handle handle;
};


Resumable range(int low, int high)
{
    while(low < high) {
        co_await Awaiter(low++);
    }
}


int main()
{
    int low = 0, high = 10;
    Resumable iter = range(low, high);
    while(true) {
        try{
            cout.flush() << iter.get() << " ";
            iter.resume();
        }catch (...) {
            break;
        }
    }
}

Out:

0 1 2 3 4 5 6 7 8 9 

当然,我上面为了学习理解所以强行将awaiter作了一个产出器,实际上这个工作应该交由co_yield来完成,他会调用promise_type.yield_value(expr),可以直接从promise中拿到数值,更为简便。

具体的例子可参考cpp reference 里的这篇文章。

HelloWorld的协程实现

考虑下面的代码:

async_void func()
{
	cout << "hello ";
	co_await std::suspend_always();
	cout << "world" << endl;
}


int main()
{
	auto f = func();
	f.resume();
}

经过协程优化以后:

async_void coro_func()
{
    Frame *frame = operator new(size);	//	size = 函数形参大小 + 局部变量大小
	async_void::promise_type promise;
	async_void ret = promise.get_return_objet();

	int status = 0;
	void resume() {
		switch(status) {
		case 0:
			return f0();
		case 1:
			return f1();
		}
	}
	void f0() { status = 1; cout << "hello "; }
	void f1() { cout << "world" << endl; }

	return ret;
}

写得有点累了,先发这么多,如果有人看就继续往下写。

class async_void
{
public:
	class promise_type {
	public:
	    auto get_return_object() { return async_void{Handle::from_promise(*this)}; }
	    auto initial_suspend() { return std::suspend_never(); }
	    void unhandled_exception() { throw; }
	    auto final_suspend() { return std::suspend_never(); }
	    void return_void() {}
	};
	
	typedef std::coroutine_handle<promise_type> Handle;
	explicit async_void(Handle h):handle(h) {}
	
	Handle handle;
	bool resume() {
		if(!handle.done())
			handle.resume();
		return !handle.done();
	}
}

async_void func()
{
	cout << "hello ";
	co_await std::suspend_always();
	cout << "world" << endl;
}

int main()
{
	auto coro = func();
	while(coro.resume());
}
async_void func()
{
    Frame *frame = operator new(size);				//	size = 函数形参大小 + 局部变量大小
    async_void::promise_type promise;
    coroutine_handle *handle = coroutine_handle<>::from_promise(&promise);
    async_void res = promise.get_return_object();	//	call the Resumable constructor 

    co_await promise.initial_suspend();				//	in some ways, this is a coroutine constructor
    try {
        cout << "hello ";
        co_await std::suspend_always();
        cout << "world" << endl;
    }catch (...) {
        promise.unhandled_exception();				//	coroutine exception handle
    }
    co_await promise.final_suspend();				//	in some ways, this is a coroutine destructor

    return res;
}
auto a = std::suspend_always();
if(!a.await_ready()) {
	try {
		a.await_suspend(handle);
		return_to_caller();
	}catch(...) {
		exception = std::current_exception();
		goto resume_point();
	}
}

resume_point:
	if(exception)
		std::rethrow_exception(exception);
	return a.await_resume();
async_void hi() {
	cout << "h";
	co_await std::suspend_always();
	cout << "i";
}

async_void func() {
	cout << "hello";
	co_await hi().await;
	cout << "world";
}

int main() {
	auto coro1 = hi();
	auto coro2 = func();
	coro1.resume();
	coro2.resume();
}

参考

  1. cpp reference
  2. 阿里安龙飞的视频
  3. 知乎启蒙帖

你可能感兴趣的:(C++,c++,协程,C++20,coroutine,stl)