C++协程库实现

概念

        协程,是一个程序组件,其功能其实就是执行一段可由用户随时中断或继续运行的代码,可与异步机制结合。一个线程中,可存在多个协程。

实现

        windows平台下具有Fiber概念,其API提供了创建CreateFiberEx、删除DeleteFiber、切换SwitchToFiber的接口,Fiber可看作是协程在windows平台下的实现。

        linux平台下我这里使用glibc提供的ucontext实现(感谢网友的贡献,让我抄一下代码),需手动实现保存运行环境的上下文。

        我为了简化编程,使用了C++11新特性下的std::function、thread_local。使用thread_local可直接获得本线程的协程调度器,简便,且不用自己根据线程环境来获取到对应的协程调度器,但其局限性就是不可跨线程调度协程。本文只实现了linux、windows两个平台的环境。

        其他具体的信息,这里就不说了,网上也挺多关于协程概念的描述。

        提供的接口:

/*
 * 开启当前线程的协程环境,只有开启后,才能对协程进行操作
 * 只对调用者所在线程有效
 */
extern void OpenCoroutineEnv();

/*
 * 关闭当前线程的协程环境,关闭后,当前线程的协程的所有操作将无效
 * 只对调用者所在线程有效
 */
extern void CloseCoroutineEnv();

/*
 * 在当前线程中创建协程
 */
extern Awaitable CreateCoroutine(Callable &&callable);

/*
 * 在当前线程中创建协程,并立即运行
 */
extern Awaitable CreateAndStartCoroutine(Callable &&callable);

/* 
 * 唤醒协程,在该协程上一次跳出的地方继续运行
 * 不可跨线程唤醒协程,只能对当前线程的协程操作,否则会得到意想不到的效果
 */
extern void ResumeCoroutine(Awaitable awaiter);

/*
 * 跳出当前协程
 */
extern void YieldCoroutine();

/*
 * 终止当前协程
 */
extern void EndCoroutine();

/*
 * 获取协程的状态
 */
extern CoroutineStatus GetCoroutineStatus(Awaitable awaiter);

/*
 * 获取当前正在运行的协程
 */
extern Awaitable GetCurrentCoroutine();

// 简化操作的命名
#define co_yield YieldCoroutine();
#define co_return EndCoroutine();
#define co_create(c) CreateCoroutine([&](){c;});
#define co_await(c) CreateAndStartCoroutine([&](){c;});
#define co_resume(c) ResumeCoroutine(c);
#define co_status GetCoroutineStatus(GetCurrentCoroutine());
#define co_running GetCurrentCoroutine();

源码

        语言环境:C++11

头文件:Coroutine.h

#ifndef COROUTINE_H
#define COROUTINE_H

#include 

enum class CoroutineStatus
{
    DEAD, // 终止状态
    RUNNABLE, // 就绪状态
    RUNNING, // 运行状态
    SUSPEND // 暂停状态
};

using Awaitable = long long;

using Callable = std::function;

/*
 * 开启当前线程的协程环境
 * 只对调用者所在线程有效
 */
extern void OpenCoroutineEnv();

/*
 * 关闭当前线程的协程环境,关闭后,当前线程的协程的所有操作将无效
 * 只对调用者所在线程有效
 */
extern void CloseCoroutineEnv();

/*
 * 在当前线程中创建协程
 */
extern Awaitable CreateCoroutine(Callable &&callable);

/*
 * 在当前线程中创建协程,并立即运行
 */
extern Awaitable CreateAndStartCoroutine(Callable &&callable);

/* 
 * 唤醒协程,在该协程上一次跳出的地方继续运行
 * 不可跨线程唤醒协程,只能对当前线程的协程操作,否则会得到意想不到的效果
 */
extern void ResumeCoroutine(Awaitable awaiter);

/*
 * 跳出当前协程
 */
extern void YieldCoroutine();

/*
 * 终止当前协程
 */
extern void EndCoroutine();

/*
 * 获取协程的状态
 */
extern CoroutineStatus GetCoroutineStatus(Awaitable awaiter);

/*
 * 获取当前正在运行的协程
 */
extern Awaitable GetCurrentCoroutine();

// 操作命名
#define co_yield YieldCoroutine();
#define co_return EndCoroutine();
#define co_create(c) CreateCoroutine([&](){c;});
#define co_await(c) CreateAndStartCoroutine([&](){c;});
#define co_resume(c) ResumeCoroutine(c);
#define co_status GetCoroutineStatus(GetCurrentCoroutine());
#define co_running GetCurrentCoroutine();

#endif // COROUTINE_H

源文件:Coroutine.cpp

#include "Coroutine.h"
#include 

#define OS_WIN 0
#define OS_LINUX 1

#ifdef __WIN32
#define OS OS_WIN
#elif defined(__linux__) 
#define OS OS_LINUX
#endif

#if OS == OS_WIN
#include 
typedef LPVOID ContextType;
#elif OS == OS_LINUX
#include  // memcpy
#include 
typedef ucontext_t ContextType;
#endif

// 栈大小
#ifndef CO_STACK_SIZE
#define CO_STACK_SIZE 1024 * 1024
#endif

struct Schedule;
struct Coroutine;

/*
 * 删除协程及其占用的空间
 */
static void DeleteCoroutine(Coroutine *co);

/*
 * 协程函数入口
 */
static void CoroutineMain(void *lpParameter);

#if OS == OS_LINUX
/*
 * 保存协程的当前栈空间信息
 */
static void SaveStack(Coroutine *co, char *top);
#endif

struct Coroutine
{
	CoroutineStatus status; // 协程状态
	ContextType ctx; // 协程上下文
	Callable func; // 协程执行函数
#if OS == OS_LINUX
	std::size_t size; // 占用的栈空间大小
	std::size_t cap;
	char *stack = nullptr;
#endif
};

struct Schedule
{
	long long cur_co;  // 当前执行的协程ID
	ContextType main; // 原上下文
	std::map active_co_list;// 活动中的协程
	std::map dead_list;// 已消亡的协程
#if OS == OS_LINUX
	char stack[CO_STACK_SIZE];//栈空间
#endif
};

static thread_local Schedule *_S_schedule = nullptr;

void OpenCoroutineEnv()
{
	if(_S_schedule != nullptr)
	{
		return;
	}
    _S_schedule = new Schedule;
	_S_schedule->cur_co = -1;
	_S_schedule->active_co_list.clear();
	_S_schedule->dead_list.clear();
#if OS == OS_WIN
	_S_schedule->main = ConvertThreadToFiberEx(NULL, FIBER_FLAG_FLOAT_SWITCH);
#endif
}

void CloseCoroutineEnv()
{
	if(_S_schedule == nullptr)
	{
		return;
	}
	// 不能在协程中关闭
	if(_S_schedule->cur_co > 0)
	{
		return;
	}

	for(auto &p : _S_schedule->active_co_list)
	{
		DeleteCoroutine(p.second);
	}
	for(auto &p : _S_schedule->dead_list)
	{
		DeleteCoroutine(p.second);
	}

	_S_schedule->active_co_list.clear();
	_S_schedule->dead_list.clear();
	delete _S_schedule;
	_S_schedule = nullptr;

#if OS == OS_WIN
	ConvertFiberToThread();
#endif
}

Awaitable CreateCoroutine(Callable &&callable)
{
	if(_S_schedule == nullptr)
	{
		return -1;
	}

	long long key = 1;
	Coroutine *co = new Coroutine();
	co->status = CoroutineStatus::RUNNABLE;
	co->func = std::move(callable);

#if OS == OS_WIN
	co->ctx = CreateFiberEx(CO_STACK_SIZE, 0, FIBER_FLAG_FLOAT_SWITCH, CoroutineMain, _S_schedule);
#elif OS == OS_LINUX
	getcontext(&co->ctx);
	co->ctx.uc_stack.ss_sp = _S_schedule->stack;
	co->ctx.uc_stack.ss_size = CO_STACK_SIZE;
	co->ctx.uc_link = &_S_schedule->main;
	makecontext(&co->ctx, reinterpret_cast(CoroutineMain), 1, _S_schedule);
#endif
	if(!_S_schedule->active_co_list.empty())
	{
		key = _S_schedule->active_co_list.rbegin()->first + 1;
	}
	while(key <= 0 || _S_schedule->active_co_list.count(key) > 0)
	{
		++key;
	}
	_S_schedule->active_co_list.insert(std::make_pair(key, co));
	return key;
}

Awaitable CreateAndStartCoroutine(Callable &&callable)
{
	Awaitable ret = CreateCoroutine(std::move(callable));
	ResumeCoroutine(ret);
	return ret;
}

void ResumeCoroutine(Awaitable awaiter)
{
	if(_S_schedule == nullptr)
	{
		return;
	}

	long long coid = awaiter;
	for(auto &p : _S_schedule->dead_list)
	{
		DeleteCoroutine(p.second);
	}

	_S_schedule->dead_list.clear();
	std::map::iterator found =  _S_schedule->active_co_list.find(coid);
	if(found == _S_schedule->active_co_list.end())
	{
		return;
	}

	Coroutine *co = found->second;
	switch(co->status)
	{
	case CoroutineStatus::RUNNABLE:
	case CoroutineStatus::SUSPEND:
		co->status = CoroutineStatus::RUNNING;
		_S_schedule->cur_co = coid;
#if OS == OS_WIN
		SwitchToFiber(co->ctx);
#elif OS == OS_LINUX
		memcpy(_S_schedule->stack + CO_STACK_SIZE -  co->size, co->stack, co->size);
		swapcontext(&_S_schedule->main, &co->ctx);
#endif
		break;
	default: break;

	}

}

#if OS == OS_LINUX
void SaveStack(Coroutine *co, char *top)
{
	char dummy = 0;
	if(co->cap < (std::size_t)(top - &dummy))
	{
		if(co->stack != nullptr)
		{
			delete[] co->stack;
		}
		co->cap = top - &dummy;
		co->stack = new char[co->cap];
	}

	co->size = top - &dummy;
	memcpy(co->stack, &dummy, co->size);
}
#endif

void YieldCoroutine()
{
	if(_S_schedule == nullptr)
	{
		return;
	}
	std::map::iterator found = _S_schedule->active_co_list.find(_S_schedule->cur_co);
	if(found == _S_schedule->active_co_list.end())
	{
		return;
	}
	Coroutine *co = found->second;
	co->status = CoroutineStatus::SUSPEND;
	_S_schedule->cur_co = -1;

#if OS == OS_WIN
	SwitchToFiber(_S_schedule->main);
#elif OS == OS_LINUX
	SaveStack(co, _S_schedule->stack + CO_STACK_SIZE);
	swapcontext(&co->ctx, &_S_schedule->main);
#endif
}

void EndCoroutine()
{
	if(_S_schedule == nullptr)
	{
		return;
	}
	std::map::iterator found =  _S_schedule->active_co_list.find(_S_schedule->cur_co);

	if(found == _S_schedule->active_co_list.end())
	{
		return;
	}

	Coroutine *co = found->second;
	co->status = CoroutineStatus::DEAD;

	_S_schedule->dead_list.insert(*found);
	_S_schedule->active_co_list.erase(found);
	_S_schedule->cur_co = -1;

#if OS == OS_WIN
	SwitchToFiber(_S_schedule->main);
#elif OS == OS_LINUX
	swapcontext(&co->ctx, &_S_schedule->main);
#endif
}

CoroutineStatus GetCoroutineStatus(Awaitable awaiter)
{
	if(_S_schedule == nullptr)
	{
		return CoroutineStatus::DEAD;
	}
	long long coid = awaiter;
	std::map::iterator found =  _S_schedule->active_co_list.find(coid);
	if(found == _S_schedule->active_co_list.end())
	{
		return CoroutineStatus::DEAD;
	}
	return found->second->status;
}

Awaitable GetCurrentCoroutine()
{
	if(_S_schedule == nullptr)
	{
		return -1;
	}
	return _S_schedule->cur_co;
}

void DeleteCoroutine(Coroutine *co)
{
#if OS == OS_WIN
	DeleteFiber(co->ctx);
#elif OS == OS_LINUX
	if(co->stack != nullptr)
	{
		delete[] co->stack;
	}
#endif
	delete co;
}

void CoroutineMain(void *lpParameter)
{
	Schedule *s = reinterpret_cast(lpParameter);
	std::map::iterator found =  s->active_co_list.find(s->cur_co);
	if(found == s->active_co_list.end())
	{
		return;
	}

	Coroutine *co = found->second;
	if(co->func)
	{
		co->func();
	}
	s->cur_co = -1;
	s->active_co_list.erase(found);

#if OS == OS_WIN
	SwitchToFiber(s->main);
#elif OS == OS_LINUX
	swapcontext(&co->ctx, &s->main);
#endif
}

测试代码

#include 
#include "Coroutine.h"

static void func(int min, int max)
{
    int c = 0;
    for(int i = min; i <= max; ++i)
    {
        std::cout << i << std::endl;
        if(++c % 5 == 0)
        {
            co_yield
        }
        if(c % 10 == 0)
        {
            co_return
        }
    }
}

static void func2(int min, int max)
{
    int c = 0;
    for(int i = min; i <= max; ++i)
    {
        std::cout << i << std::endl;
        if(++c % 5 == 0)
        {
            co_yield
        }
    }
}

static void main_pro()
{
    OpenCoroutineEnv();
    Awaitable cid1 = co_await(func(10, 29));
    Awaitable cid2 = co_await(func(30, 49));
    Awaitable cid3 = co_await(func2(70, 100));

    co_resume(cid1);
    co_resume(cid2);
    co_resume(cid3);

    co_resume(cid1);
    co_resume(cid2);
    co_resume(cid3);

    co_resume(cid1);
    co_resume(cid2);
    co_resume(cid3);

    co_resume(cid1);
    co_resume(cid2);
    co_resume(cid3);
    
    CloseCoroutineEnv();
}

int main()
{
    main_pro();
    return 0;
}

测试结果

如有问题,欢迎指出!

你可能感兴趣的:(C++应用,c++,协程)