简介
本文使用 C++20 引入的协程来编写一个 Linux epoll 程序。在此实现中,用户使用异步操作时再也无需提供自己的回调函数。以此处实现的 `asyncRead()` 为例:
- 使用 `asyncRead()` 所需的参数和 `read()` 大致相同,无需传入回调;`asyncRead()` 的内部会向 epoll 注册要监听的文件描述符、感兴趣的事件和要执行的回调(由实现提供,而无需使用者传入);
- 当事件未就绪时,`co_await asyncRead()` 会挂起当前协程;
- 当事件就绪时,epoll 循环中会执行具体的 I/O 操作(此处将其提交到 I/O 线程池中执行),当 I/O 操作完成时,恢复协程的运行。
1. ThreadPool
此处使用了两个线程池:
- I/O 线程池:用于执行 I/O 操作;
- 任务线程池:用于处理客户端连接(此处以 TCP 回显程序为例)。
此处使用的是自己实现的线程池,具体实现见 https://segmentfault.com/a/11...。
2. IOContext
`IOContext` 类对 Linux epoll 做了简单的封装。
io_context.h:
#ifndef IOCONTEXT_H
#define IOCONTEXT_H
#include <functional>
#include <mutex>
#include <unordered_map>
#include <sys/epoll.h>
#include "thread_pool.h"
// Callback run by the event loop once a watched fd becomes ready. It takes no
// arguments: everything it needs is captured by the closure.
using callback_t = std::function<void()>;

// Per-registration data stored by IOContext::post() for a watched fd.
struct Args
{
    callback_t m_cb;  // invoked (on the I/O pool) when the fd reports an event
};
class IOContext
{
public:
IOContext(int nIOThreads=2, int nJobThreads=2);
// 监听文件描述符 fd,感兴趣事件为 events,args 中包含要执行的回调
bool post(int fd, int events, const Args& args);
// 提交任务至任务线程池
bool post(const Task& task);
// 不再关注文件描述符 fd,并移除相应的回调
void remove(int fd);
// 持续监听、等待事件就绪
void run();
private:
int m_fd;
std::unordered_map m_args;
std::mutex m_lock;
ThreadPool m_ioPool; // I/O 线程池
ThreadPool m_jobPool; // 任务线程池
};
#endif
io_context.cpp:
#include "io_context.h"
#include
#include
#include
// Creates the epoll instance and starts the I/O and job thread pools.
// epoll_create1() replaces epoll_create(): the size hint is ignored by modern
// kernels, and EPOLL_CLOEXEC keeps the descriptor out of exec'd children.
IOContext::IOContext(int nIOThreads, int nJobThreads)
    : m_ioPool(nIOThreads), m_jobPool(nJobThreads)
{
    m_fd = epoll_create1(EPOLL_CLOEXEC);
    if (m_fd < 0)
    {
        // A constructor cannot return an error; log it (as remove() does) so the
        // subsequent post()/run() failures have an explanation.
        std::cout << "IOContext: " << strerror(errno) << "\n";
    }
}
// Registers fd with epoll for `events` and stores args so run() can invoke
// args.m_cb when the fd fires. Returns false, storing nothing, if epoll_ctl
// rejects the fd (for example EEXIST when it is already registered).
bool IOContext::post(int fd, int events, const Args& args)
{
    // Zero-fill so the unused bytes of the `data` union are not indeterminate.
    struct epoll_event event{};
    event.events = events;
    event.data.fd = fd;
    // Hold the lock across epoll_ctl and the map update: the event may fire
    // immediately, and run() takes this lock before looking up the callback,
    // so it can never observe the fd without its callback stored.
    std::lock_guard lock(m_lock);
    int err = epoll_ctl(m_fd, EPOLL_CTL_ADD, fd, &event);
    if (err == 0)
    {
        m_args[fd] = args;
    }
    return err == 0;
}
// Forwards a task to the job thread pool and reports whether the pool accepted it.
bool IOContext::post(const Task& task)
{
    const bool accepted = m_jobPool.submitTask(task);
    return accepted;
}
void IOContext::remove(int fd)
{
std::lock_guard lock(m_lock);
int err = epoll_ctl(m_fd, EPOLL_CTL_DEL, fd, nullptr);
if (err == 0)
{
m_args.erase(fd);
}
else
{
std::cout << "remove: " << strerror(errno) << "\n";
}
}
void IOContext::run()
{
int timeout = -1;
size_t nEvents = 32;
struct epoll_event* eventList = new struct epoll_event[nEvents];
while (true)
{
int nReady = epoll_wait(m_fd, eventList, nEvents, timeout);
if (nReady < 0)
{
delete []eventList;
return;
}
for (int i = 0; i < nReady; i++)
{
int fd = eventList[i].data.fd;
m_lock.lock();
auto cb = m_args[fd].m_cb;
m_lock.unlock();
remove(fd);
m_ioPool.submitTask([=]()
{
cb();
});
}
}
}
3. Awaitable
实现 C++ 协程所需的类型,详细解释见 https://segmentfault.com/a/11...。
awaitable.h:
#ifndef AWAITABLE_H
#define AWAITABLE_H
#include
#include
#include
#include
#include "io_context.h"
// Which syscall an Awaitable's callback performs once its fd is ready:
// read(), write(), or accept() on a listening socket.
enum class HandlerType
{
    Read,
    Write,
    Accept,
};
// Awaiter for one epoll-driven I/O operation. co_await always suspends;
// await_suspend() registers m_fd with the IOContext, the registered callback
// performs the syscall selected by m_ht, and await_resume() returns its result.
class Awaitable
{
public:
    Awaitable(IOContext* ctx, int fd, int events, void* buf, size_t n, HandlerType ht);
    bool await_ready();
    void await_suspend(std::coroutine_handle<> handle);
    int await_resume();
private:
    IOContext* m_ctx;   // event loop the fd is registered with
    int m_fd;           // fd to wait on and then operate on
    int m_events;       // epoll events to wait for (EPOLLIN / EPOLLOUT)
    void* m_buf;        // I/O buffer (nullptr for Accept)
    size_t m_n;         // buffer size in bytes (0 for Accept)
    // Return value of the syscall. The constructor never sets it, so default to
    // -1 rather than leaving await_resume() able to read an indeterminate value.
    int m_result = -1;
    HandlerType m_ht;   // which syscall the callback performs
};
// Return type of the fire-and-forget coroutines in this program. It carries no
// state: both suspend points are std::suspend_never, so the body starts running
// immediately and the frame is released when the body finishes.
struct CoroRetType
{
    struct promise_type
    {
        CoroRetType get_return_object();
        std::suspend_never initial_suspend();
        std::suspend_never final_suspend() noexcept;
        void return_void();
        void unhandled_exception();
    };
};
#endif
awaitable.cpp:
#include
#include "awaitable.h"
Awaitable::Awaitable(IOContext *ctx, int fd, int events, void* buf, size_t n, HandlerType ht)
: m_ctx(ctx), m_fd(fd), m_events(events), m_buf(buf), m_n(n), m_ht(ht)
{}
// Never complete synchronously: always suspend so the fd is first registered
// with epoll and the syscall runs only after readiness is reported.
bool Awaitable::await_ready()
{
    constexpr bool kCompletesImmediately = false;
    return kCompletesImmediately;
}
int Awaitable::await_resume()
{
return m_result;
}
// Registers m_fd with the IOContext for m_events. The registered callback runs
// the syscall selected by m_ht once epoll reports readiness, stores its return
// value in m_result, and resumes the coroutine.
//
// The callback may run (and resume, possibly finishing the coroutine) on an
// I/O-pool thread before this function returns, so *this is not touched after
// the callback has been successfully published.
void Awaitable::await_suspend(std::coroutine_handle<> handle)
{
    auto cb = [handle, this]()
    {
        switch (m_ht)
        {
        case HandlerType::Read:
            m_result = read(m_fd, m_buf, m_n);
            break;
        case HandlerType::Write:
            m_result = write(m_fd, m_buf, m_n);
            break;
        case HandlerType::Accept:
            m_result = accept(m_fd, nullptr, nullptr);
            break;
        }
        handle.resume();
    };
    if (!m_ctx->post(m_fd, m_events, Args{cb}))
    {
        // Registration failed, so no event will ever arrive for this operation.
        // Report failure and resume now instead of leaving the coroutine
        // suspended forever. The coroutine is already suspended here, so
        // resuming it from inside await_suspend is allowed.
        m_result = -1;
        handle.resume();
    }
}
// The return object carries no state, so a value-initialized one suffices.
CoroRetType CoroRetType::promise_type::get_return_object()
{
    return {};
}

// Run the coroutine body eagerly, on the caller's thread, until its first
// real suspension.
std::suspend_never CoroRetType::promise_type::initial_suspend()
{
    return {};
}

// Do not park at the end; the coroutine frame is released automatically.
std::suspend_never CoroRetType::promise_type::final_suspend() noexcept
{
    return {};
}

// Coroutines of this type produce no value.
void CoroRetType::promise_type::return_void()
{
}

// An exception escaping a coroutine body aborts the program.
void CoroRetType::promise_type::unhandled_exception()
{
    std::terminate();
}
4. 异步操作
使用协程来封装异步操作。
io_util.h:
#ifndef IO_UTIL_H
#define IO_UTIL_H
#include "io_context.h"
#include "awaitable.h"
// Factories for single-step I/O awaitables. Each returned Awaitable, when
// co_awaited, waits on ctx until fd is ready, then performs the corresponding
// syscall; the co_await expression yields that syscall's return value.
// Waits for EPOLLIN on fd, then read(fd, buf, n).
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n);
// Waits for EPOLLOUT on fd, then write(fd, buf, n).
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n);
// Waits for EPOLLIN on the listening socket fd, then accept(fd, nullptr, nullptr).
Awaitable asyncAccept(IOContext* ctx, int fd);
#endif
io_util.cpp:
#include "io_util.h"
// Awaitable that waits until fd is readable, then read()s up to n bytes into buf.
Awaitable asyncRead(IOContext* ctx, int fd, void* buf, size_t n)
{
    return {ctx, fd, EPOLLIN, buf, n, HandlerType::Read};
}
// Awaitable that waits until fd is writable, then write()s n bytes from buf.
Awaitable asyncWrite(IOContext* ctx, int fd, void* buf, size_t n)
{
    return {ctx, fd, EPOLLOUT, buf, n, HandlerType::Write};
}
// Awaitable that waits until the listening socket fd is readable, then
// accept()s one connection; no buffer is involved.
Awaitable asyncAccept(IOContext* ctx, int fd)
{
    return {ctx, fd, EPOLLIN, nullptr, 0, HandlerType::Accept};
}
5. 例子
main.cpp:
#include
#include
#include
#include
#include
#include
#include
#include "io_util.h"
static std::mutex ioLock;     // serializes std::cout output from the client threads
static uint16_t port = 6666;  // TCP port the echo server listens on
static int backlog = 32;      // listen() backlog
// Message each client sends. MsgLen is derived from the literal (excluding the
// terminating '\0') so the two can never drift apart.
static constexpr char Msg[] = "hello, cpp!";
static constexpr size_t MsgLen = sizeof(Msg) - 1;
// The single event loop shared by every coroutine in this example; main()
// drives it by calling ioContext.run() on the main thread.
static IOContext ioContext;
// Serves one client connection as a fire-and-forget coroutine: reads up to
// MsgLen bytes, echoes them back, and closes the socket.
CoroRetType handleConnection(int fd)
{
    char buf[MsgLen + 1] = {0};
    int n = co_await asyncRead(&ioContext, fd, buf, MsgLen);
    // 0 means the peer closed and -1 means an error: nothing to echo. Passing a
    // negative n on to write() would turn into a huge size_t length.
    if (n > 0)
    {
        // n <= MsgLen, so buf[n] is the last valid slot of buf.
        buf[n] = '\0';
        co_await asyncWrite(&ioContext, fd, buf, n);
    }
    close(fd);
}
// Listening coroutine: binds a TCP socket on `port`, then loops forever,
// awaiting each incoming connection and posting handleConnection() for it to
// the job pool. If the socket cannot be set up, it logs the failure and ends.
CoroRetType serverThread()
{
    int listenSock = socket(AF_INET, SOCK_STREAM, 0);
    if (listenSock < 0)
    {
        std::cout << "serverThread: socket failed\n";
        co_return;
    }
    int value = 1;
    setsockopt(listenSock, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(int));
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(listenSock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) < 0 ||
        listen(listenSock, backlog) < 0)
    {
        std::cout << "serverThread: bind/listen failed\n";
        close(listenSock);
        co_return;
    }
    while (true)
    {
        int clientSock = co_await asyncAccept(&ioContext, listenSock);
        if (clientSock < 0)
        {
            // accept() failed; never hand an invalid fd to handleConnection().
            continue;
        }
        auto h = [=]()
        {
            handleConnection(clientSock);
        };
        ioContext.post(h);
    }
}
// Test client (runs on its own thread): after a 1s delay, connects to the local
// echo server, sends Msg, reads the echo, and prints it under ioLock.
void clientThread()
{
    using namespace std::literals;
    std::this_thread::sleep_for(1s);
    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0)
    {
        return;
    }
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
    char buf[MsgLen + 1] = {0};
    ssize_t n = 0;
    if (connect(sock, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)) == 0 &&
        write(sock, Msg, MsgLen) == static_cast<ssize_t>(MsgLen))
    {
        // Read at most MsgLen bytes so buf always has room for the terminator.
        // The old code read `write()`'s result, which could be -1 (huge size_t).
        n = read(sock, buf, MsgLen);
    }
    // Terminate right after the received bytes; buf[n] is in bounds since n <= MsgLen.
    buf[n > 0 ? n : 0] = '\0';
    {
        std::lock_guard lock(ioLock);
        std::cout << "clientThread: " << buf << '\n';
    }
    close(sock);
}
int main()
{
serverThread();
constexpr int N = 10;
for (int i = 0; i < N; i++)
{
std::thread t(clientThread);
t.detach();
}
ioContext.run();
}
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!
clientThread: hello, cpp!