C++协程入门

1 什么是协程

定义

  • 协程可以理解为用户态轻量级线程;
  • 协程拥有自己的上下文和栈;
  • 协程的切换和调度由用户定义,不用陷入内核;
  • 如同一个进程拥有多个线程,一个线程可以拥有多个协程。

优点

  • 极高的执行效率 因为协程切换不用陷入内核,是由用户程序定义切换逻辑,因此协程没有线程切换的开销。
  • 以同步代码的方式写异步逻辑 可开发异步IO。

缺点

由于协程是在单个线程内切换的,无法利用多核资源。结合多进程/多线程可以解决这个问题。

2 协程的实现

协程在一些脚本语言如Python、Lua中都已经很好地支持了(C++20也支持协程),但为了更好地学习它,还是有必要去逐步封装一个协程。本文主要利用Linux的ucontext库去封装一个Coroutine类,再与boost的基于fcontext的Coroutine库做一个对比。

2.1 基于ucontext的实现

ucontext.h 简介

头文件提供以下4个调用:

int getcontext(ucontext_t * ucp);
int setcontext(const ucontext_t *ucp);
void makecontext(ucontext_t *ucp, void(*func)(), int argc, ...);
int swapcontext(ucontext_t *oucp, ucontext_t *ucp);

以下逐个介绍它们:

  1. int getcontext(ucontext_t * ucp);
  • 获取当前上下文, 初始化ucp结构体, 将当前上下文保存到ucp中;
  • 成功不返回,失败返回-1, 并设置errno。
  1. void makecontext(ucontext_t *ucp, void(*func)(), int argc, ...);
  • 创建一个目标上下文 ,用于初始化一个协程,并且上下文需要
    • ucp来自getcontext调用;
    • 指定分配给上下文的栈uc_stack.ss_sp
    • 指定这块栈的大小uc_stack.ss_size,如32K, 64K, 128K;
    • 指定uc_stack.ss_flags,一般为0;
    • 指定后继上下文uc_link
  1. int setcontext(const ucontext_t *ucp);
  • 设置当前的上下文为ucp
  • ucp来自getcontext, 那么上下文恢复至ucp
  • ucp来自makecontext, 那么将会调用makecontext函数的第二个参数指向的函数func, 如果func返回, 则恢复至ucp->uc_link指定的后继上下文, 如果该ucp中的uc_linkNULL, 那么线程退出;
  • 成功不返回, 失败返回-1, 设置errno。
  1. int swapcontext(ucontext_t *oucp, ucontext_t *ucp);
  • 切换上下文
  • 保存当前上下文至oucp, 恢复ucp上下文(先执行makecontext指定的ucp入口函数, 而后会执行ucp->uc_link指向的后继上下文);
  • 成功不返回, 失败返回-1, 设置errno。
Coroutine类的实现

主要需要实现yield语义和resume语义,即协程主动让出控制权和恢复某个特定的协程。另外,需要定义4种状态,即INIT(初始)、RUNNING(运行)、SUSPEND(挂起)、TERM(结束),便于切换时检查协程的状态。由于协程是在线程内切换的,类似进程拥有一个主线程,线程内也应拥有一个主协程,用线程局部变量(thread_local)定义。

coroutine.h

#ifndef _COROUTINE_H_
#define _COROUTINE_H_
#include 
#include 
#include 

class Coroutine : public std::enable_shared_from_this<Coroutine> {
public:
    typedef std::shared_ptr<Coroutine> ptr;
    typedef std::function<void()> Callback;
    enum State {
        INIT,          // 初始状态
        RUNNING,       // 运行中状态
        SUSPEND,       // 挂起状态
        TERM           // 结束状态
    };
    Coroutine(Callback cb, size_t stacksize = 0);
    
    ~Coroutine();
    uint64_t getId() { return id_; }
    State getState() { return state_; }
    void swapOut();
    void swapIn();
public:
    static void Resume(const Coroutine::ptr crt);
    static void Yield();
    static void SetThis(Coroutine* crt);
    static Coroutine::ptr GetThis();
    static void MainFunc();
private:
    Coroutine();
    uint64_t id_ = 0;
    uint32_t stacksize_ = 0;
    State state_ = INIT;
    ucontext_t ctx_;
    void* stack_ = nullptr; // 协程堆栈
    Callback cb_;

    struct Comparator {
        bool operator()(const Coroutine::ptr& lhs, const Coroutine::ptr& rhs) const;
    };
};
#endif

coroutine.cpp

#include "coroutine.h"
#include 
#include 

static std::atomic<uint64_t> s_id {0};
static uint32_t g_crt_stack_size = 128 * 1024;
static thread_local Coroutine* t_crt = nullptr;
static thread_local Coroutine::ptr t_main_crt = nullptr;

bool Coroutine::Comparator::operator()(const Coroutine::ptr& lhs
                        ,const Coroutine::ptr& rhs) const {
    return lhs->id_ < rhs->id_;
}

Coroutine::Coroutine(Callback cb, size_t stacksize) 
    :id_(++s_id),
     cb_(std::move(cb)) {
    if (getcontext(&ctx_) < 0) {
		std::cerr << "getcontext failed!" << std::endl;
        exit(1);
    }
    stacksize_ = stacksize ? stacksize : g_crt_stack_size;
    stack_ = malloc(stacksize_);
    memset(stack_, 0, stacksize_);
    ctx_.uc_stack.ss_sp = stack_;
    ctx_.uc_stack.ss_size = stacksize_;
    ctx_.uc_link = nullptr;

    makecontext(&ctx_, &Coroutine::MainFunc, 0);
}

Coroutine::Coroutine() {
    state_ = RUNNING;
    SetThis(this);
    if (getcontext(&ctx_) < 0) {
        std::cerr << "getcontext failed!" << std::endl;
        exit(1);
    }
}

Coroutine::~Coroutine() {
    if (stack_) { // work crt
        assert(state_ == TERM || state_ == INIT);
        free(stack_);
    } else { // main crt
        assert(!cb_ && state_ == RUNNING && t_crt == this);
        SetThis(nullptr);
    }
}

Coroutine::ptr Coroutine::GetThis() {
    if (t_crt) return t_crt->shared_from_this();
    Coroutine::ptr main_ctr(new Coroutine);
    assert(t_crt == main_ctr.get());
    t_main_crt = main_ctr;
    return t_crt->shared_from_this();
}

void Coroutine::SetThis(Coroutine* crt) {
    t_crt = crt;
}

void Coroutine::swapOut() {
    SetThis(t_main_crt.get());
    if (swapcontext(&ctx_, &t_main_crt->ctx_) < 0) {
        std::cerr << "swapcontext failed!" << std::endl;
        exit(1);
    }
}

void Coroutine::swapIn() {
    SetThis(this);
    if (swapcontext(&t_main_crt->ctx_, &ctx_) < 0) {
        std::cerr << "swapcontext failed!" << std::endl;
        exit(1);
    }
}
// 挂起到后台
void Coroutine::Yield() {
    Coroutine::ptr cur = GetThis();
    assert(cur->state_ == RUNNING);
    cur->state_ = SUSPEND;
    cur->swapOut();
}
// 恢复到前台
void Coroutine::Resume(const Coroutine::ptr crt) {
    assert(crt->state_ != RUNNING);
    crt->state_ = RUNNING;
    crt->swapIn();
}

void Coroutine::MainFunc() {
    Coroutine::ptr cur = GetThis();
    try {
        cur->cb_();
        cur->cb_ = nullptr;
        cur->state_ = TERM;
    } catch (std::exception& ex) {
        std::cerr << "Coroutine Except: " 
                            << ex.what() << " id = " 
                            << cur->getId() << std::endl;
    } 
    auto raw_ptr = cur.get();
    cur.reset();
    raw_ptr->swapOut();
}

写一个简单的测试程序,测试平均每次切换需要多长时间。
test.cpp

#include "coroutine.h"
#include 
uint64_t GetCurrentUS() {
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return tv.tv_sec * 1000 * 1000ul  + tv.tv_usec;
}

uint64_t cnt = 0;
const uint64_t MAX_CNT = 10000000;
void crt_func() {
    while (true) {
        ++cnt;
        Coroutine::Yield();
    }
}

void test_crt() {
    Coroutine::GetThis();
    Coroutine::ptr crt(new Coroutine(&crt_func));
    uint64_t start = GetCurrentUS();
    uint64_t end;
    while (true) {
       	Coroutine::Resume(crt);
        if (cnt > MAX_CNT) {
            end = GetCurrentUS();
            uint64_t cost = end - start;
            double oneSwitch = (double)cost / MAX_CNT;
            std::cout << "time cost = " << cost << " us" << std::endl;
            std::cout << "oneSwitch = " << oneSwitch * 1000 << " ms" << std::endl;
            break;
        }
    }
}

int main(int argc, char** argv) {
    test_crt();
    return 0;
}

运行结果如下:
ucp_test
可见平均每次切换需要800ms。

2.2 基于Boost的Coroutine

简介
  • boost::coroutines2::coroutine用来实现协程,其中协程之间可以传递类型为T的参数;
  • 协程函数体必须带有pull_typepush_type类型的参数;
  • 其中pull_type可以从push_type那里接收并返回数据,push_type可以把数据传给pull_type。简单来说,如果一个协程需要外部给它传入数据,就把它的函数体参数定义成push_type;如果需要从这个协程返回数据,就把它的函数体参数定义成pull_type
例子

基于Boost的Coroutine类是基于fcontext实现的,也就是fast context,切换非常快,以下写一个测试程序简单测一下:

#include 
#include 

uint64_t cnt = 0;
const uint64_t MAX_CNT = 10000000;
void crt_func(boost::coroutines2::coroutine<void>::push_type & sink){
    while (true) {
        ++cnt;
        sink();
    }
}

void main_func() {
    boost::coroutines2::coroutine<void>::pull_type source(crt_func);
    uint64_t start = GetCurrentUS();
    uint64_t end;
    while (true) {
        source();
        if (cnt > MAX_CNT) {
            end = GetCurrentUS();
            uint64_t cost = end - start;
            double oneSwitch = (double)cost / MAX_CNT;
            std::cout << "time cost = " << cost << std::endl;
            std::cout << "oneSwitch = " << oneSwitch * 1000 << " ms" << std::endl;
            break;
        }
    }
}
int main(int argc, char* argv[]) {
    main_func();
    return 0;
}

运行结果如下:
fct_test
可见单次切换非常快,才12 ms!!!

一个简单的排列数生成器

由于我需要从协程中返回下一个排列,所以我把协程函数体的参数定义为pull_type,即void crt_func(CrtType::pull_type & sink);。相对应的,主协程source的类型应为push_type。

#include 
#include 
#include 
#include 
#include 

uint64_t cnt = 0;
const uint64_t MAX_CNT = 10000000;
void printVec(const std::vector<int>& vec) {
    for (int v : vec) 
        std::cout << v << " ";
    std::cout << std::endl;
}
using std::placeholders::_1;
class PermGen {
public:
    typedef boost::coroutines2::coroutine<std::vector<int>&> CrtType;
    PermGen(const std::vector<int>& vec) 
    :datas(vec) {
        source.reset(new CrtType::push_type(PermGen::crt_func));
    }
    std::vector<int> next() {
        (*source)(datas);
        return datas;
    }
private:
    static void nextPermutation(std::vector<int>& nums) {
        int i = nums.size() - 2;
        while (i >= 0 && nums[i + 1] <= nums[i]) 
            --i;
        if (i >= 0) {
            int j = nums.size() - 1;
            while (j >= 0 && nums[j] <= nums[i]) 
                --j;
            std::swap(nums[i], nums[j]);
        }
        std::reverse(nums.begin() + i + 1, nums.end());
    }
    static void crt_func(CrtType::pull_type & sink) {
        while (true) {
            std::vector<int>& temp = sink.get();
            nextPermutation(temp);
            sink();
        }
    }
private:
    std::unique_ptr<CrtType::push_type> source;
    std::vector<int> datas;
};

int main(int argc, char* argv[]) {
    std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    PermGen perm(nums);
    std::vector<int> res(nums);
    do {
        printVec(res);
        res = perm.next();
    } while (res != nums);
    return 0;
}

运行结果如下:
C++协程入门_第1张图片

3 总结

当前协程已经成为一个比较流行的高并发IO基础技术,后续会深入了解ucontext、boost fcontext还有腾讯的libco的底层汇编代码,并尝试着利用协程去开发异步IO。结合网络编程,整合成高性能网络库。

你可能感兴趣的:(编程,linux,c++)