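Simple Coroutines, Part 3

New features

Building on Simple Coroutines, Part 2, this version adds support for joining (synchronously waiting on) coroutines and for IO timeouts.
A new class, JoinAction, supports joining: yielding a JoinAction object puts the coroutine into a waiting state until the target coroutine exits or the wait times out. A usage example follows.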

# Create another coroutine c and hand it to the scheduler
c = cf1()
Scheduler.add(c)
t1 = time()
# Wait for c to finish, with a timeout of 0.5 seconds; is_timeout is True if the wait timed out
is_timeout = yield JoinAction(c, 0.5)

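IO timeouts are implemented by adding a timeout parameter, also in seconds, to SocketIO. If the requested event does not arrive within the given time, the scheduler raises an exception (socket.timeout) inside the coroutine. An example follows.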

# wait until the socket is writable; timeout is 5s
yield SocketIO(sock.fileno(), read=False, timeout=5)
sock.send("data")
yield SocketIO(sock.fileno(), read=True, timeout=5)
data = sock.recv(1024)
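
As a rough illustration of how a timeout can be delivered into a coroutine, the sketch below uses the generator method throw(), which is what Coroutine.run() in the full code relies on. The names worker and run_with_timeout and the string "wait-for-read" are made up for this example; the real code yields a SocketIO object instead.

```python
import socket


def worker(log):
    """Toy coroutine: yields one request, then handles a timeout thrown into it."""
    try:
        # Stand-in for: yield SocketIO(sock.fileno(), read=True, timeout=5)
        yield "wait-for-read"
        log.append("ready")
    except socket.timeout:
        # The scheduler delivers the timeout at the yield point.
        log.append("timed out")


def run_with_timeout():
    """Play the scheduler's role: start the coroutine, then report a timeout."""
    log = []
    gen = worker(log)
    next(gen)  # run up to the first yield
    try:
        gen.throw(socket.timeout("timeout"))
    except StopIteration:
        pass  # the coroutine returned after handling the exception
    return log
```

A coroutine that does not catch the exception simply terminates with it, which the scheduler treats as an exit with an error.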

Full code

The complete code follows.

#!/usr/bin/env python
# coding: utf-8

from collections import deque
from errno import ETIMEDOUT
from heapq import heappop
from heapq import heappush
from itertools import chain
from select import select
from socket import timeout as SocketTimeoutError
from sys import exc_info
from sys import maxint
from time import sleep
from time import time
from types import GeneratorType


class Sleep(object):
    __slots__ = ["seconds", ]

    def __init__(self, seconds):
        # type: (float) -> object
        self.seconds = seconds
        # assert seconds >= 0


class SocketIO(object):
    __slots__ = ["sock_fd", "read", "timeout"]

    def __init__(self, sock_fd, read=True, timeout=-1):
        self.sock_fd = sock_fd
        self.read = read
        self.timeout = timeout


class JoinAction(object):
    __slots__ = ["target_coroutine", "timeout"]

    def __init__(self, target_coroutine, timeout=-1.0):
        # type: (GeneratorType, float) -> JoinAction
        """

        :param target_coroutine: target generator
        :param timeout: seconds
        """
        self.target_coroutine = target_coroutine
        self.timeout = timeout


class Coroutine(object):
    __slots__ = ["generator", "parent", "init_value", "exception_info", "name"]

    def __init__(self, generator, parent=None, init_value=None, exception_info=(), name=""):
        # type: (GeneratorType, Coroutine, object, tuple, str) -> Coroutine
        self.generator = generator
        self.parent = parent
        self.init_value = init_value
        self.exception_info = exception_info
        self.name = name
        if not name:
            self.name = generator.gi_code.co_name

    def __str__(self):
        return "%s.%s" % (self.name, self.cid())

    __repr__ = __str__

    def cid(self):
        return id(self.generator)

    def reset_input(self, value=None, exception_info=()):
        self.init_value = value
        self.exception_info = exception_info

    def run(self):
        if self.exception_info:
            value = self.generator.throw(*self.exception_info)
            self.exception_info = ()
        else:
            value = self.generator.send(self.init_value)
        self.init_value = value
        return value


class CoroutineError(Exception):
    pass


class FakeSocket(object):
    __slots__ = ["data"]

    def __init__(self):
        self.data = ""

    def fileno(self):
        return id(self)

    def send(self, data):
        self.data = data
        return len(data)

    def recv(self, _):
        return "HTTP/1.1 200 OK\r\nContent-Length:0\r\n\r\n"


from random import random

next_time = {}


def fake_select(rlist, wlist, xlist, timeout):
    rxlist = list(rlist)
    wxlist = list(wlist)
    return rxlist, wxlist, []


WAIT_CANCELED = 0
WAIT_SOCKET = 1
WAIT_JOIN = 2
WAIT_SLEEP = 3


class _TimeoutItem(object):
    def __init__(self, till, wait_type, arg):
        # type: (int, int, object) -> _TimeoutItem
        self.till = till
        self.wait_type = wait_type
        self.arg = arg
        self.id = id(self)


class Scheduler(object):
    _instance = None

    def __init__(self, ignore_exception=True, debug=False):
        """

        :param debug: output running detail
        :param ignore_exception: ignore coroutine's uncaught exception
        """
        self.ignore_exception = ignore_exception
        self.debug = debug
        # if true, append debug logs to _debug_logs; else, print them to stdout
        self.collect_debug_logs = False
        self._debug_logs = []
        # use fake_select() to test performance or simulation, work with FakeSocket
        self.use_fake_select = False
        #
        self.start_time = time()

        # map coroutine_id => coroutine
        self.cid2coroutine = {}
        # running queue
        self.queue = deque()
        # map: sock_fd -> [coroutine, timeout_item]
        self.sock_map = {}
        self.io_read_queue = set()
        self.io_write_queue = set()
        # map: target coroutine_id -> waiters {waiter_coroutine_id -> timeout_item, ...}, used by join waits
        self.waiting_map = {}
        # [timeout_item, ...]
        # map: millisecond (int) -> dict(item_id -> timeout_item)
        self.timer_slots_map = {}
        # [ms1, ms2, ...]
        self.millisecond_heap = []
        #
        self.alive_coroutine_num = 0
        # current running coroutine
        self.current = None
        # whether run() is calling
        self.running = False

    @classmethod
    def get_instance(cls):
        # type: () -> Scheduler
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def _debug_output(self, msg, *args):
        if self.debug:
            if self.collect_debug_logs:
                self._debug_logs.append(("%.6f" % time(), msg % args))
            else:
                print "%.6f" % time(), msg % args
        else:
            pass

    def _add(self, generator):
        co = Coroutine(generator)
        cid = co.cid()
        self.cid2coroutine[cid] = co
        self.alive_coroutine_num += 1
        self._debug_output("add new coroutine %d, alive_coroutine_num=%d",
                           cid, self.alive_coroutine_num)
        self.queue.append(co)
        return self

    def _coroutine_exit(self, coroutine, is_error):
        # type: (Coroutine, bool) -> object
        cid = coroutine.cid()
        assert cid in self.cid2coroutine
        parent = coroutine.parent
        if parent is None:
            # wake up all waiters or cancel io wait timeout
            waiters = self.waiting_map.pop(cid, None)
            if waiters:
                assert isinstance(waiters, dict)
                # join wait
                self._debug_output("%s wake up %d waiters", cid, len(waiters))
                for wcid, timeout_item in waiters.iteritems():
                    waiter = self.cid2coroutine[wcid]
                    waiter.reset_input(False)
                    self.queue.append(waiter)
                    # invalid timeout_item
                    self.timer_slots_map[timeout_item.till].pop(timeout_item.id)
                del waiters
            self.alive_coroutine_num -= 1
        else:
            if is_error:
                parent.reset_input(None, exc_info())
            else:
                parent.reset_input(coroutine.init_value, ())
            self.queue.append(parent)

        self.cid2coroutine.pop(cid)
        self._debug_output("coroutine %d exited, alive_coroutine_num=%d", cid, self.alive_coroutine_num)

    def _current_coroutine(self):
        # type: () -> Coroutine
        return self.current

    @classmethod
    def current_id(cls):
        return cls.get_instance()._current_coroutine().cid()

    @classmethod
    def current_name(cls):
        return cls.get_instance()._current_coroutine().name

    def _add_timeout(self, seconds, wait_type, arg):
        # type: (float, int, object) -> _TimeoutItem
        till = int(1000 * (time() - self.start_time + seconds + 0.0005)) if seconds >= 0 else maxint
        self._debug_output('coroutine add a timeout task at %sms from start', till)
        timeout_item = _TimeoutItem(till, wait_type, arg)
        # insert new item
        if till in self.timer_slots_map:
            self.timer_slots_map[till][timeout_item.id] = timeout_item
        else:
            self.timer_slots_map[till] = {timeout_item.id: timeout_item}
            heappush(self.millisecond_heap, till)
        return timeout_item

    def _do_coroutine_io(self, coroutine, event):
        # type: (Coroutine, SocketIO) -> object
        coroutine.reset_input()
        sock_fd = event.sock_fd
        if event.read:
            self.io_read_queue.add(sock_fd)
        else:
            self.io_write_queue.add(sock_fd)
        timeout_item = self._add_timeout(event.timeout, WAIT_SOCKET, sock_fd)
        self.sock_map[sock_fd] = [coroutine, timeout_item]

    def _do_coroutine_sleep(self, coroutine, seconds):
        coroutine.reset_input()
        timeout_item = self._add_timeout(seconds, WAIT_SLEEP, coroutine)
        self._debug_output('coroutine go to sleep until %s', timeout_item.till)

    def _do_coroutine_join(self, coroutine, event):
        # type: (Coroutine, JoinAction) -> None
        target_cid = id(event.target_coroutine)
        timeout = event.timeout
        cid = coroutine.cid()
        if cid == target_cid:
            try:
                raise CoroutineError("can't join self")
            except CoroutineError:
                coroutine.reset_input(None, exc_info())
                self.queue.append(coroutine)
        elif target_cid not in self.cid2coroutine:
            # target coroutine exited, join action ends
            coroutine.reset_input(False)
            self.queue.append(coroutine)
        elif 0 <= timeout < 0.001:
            # timeout too small, so just tell coroutine he is timeout
            coroutine.reset_input(True)
            self.queue.append(coroutine)
        else:
            self._debug_output("coroutine %s try to join %s, timeout=%f",
                               cid, target_cid, timeout)
            timeout_item = self._add_timeout(timeout, WAIT_JOIN, (cid, target_cid))
            if target_cid in self.waiting_map:
                self.waiting_map[target_cid][cid] = timeout_item
            else:
                self.waiting_map[target_cid] = {cid: timeout_item}

    # noinspection PyBroadException
    def _process_running_queue(self):
        old_queue = self.queue
        self.queue = deque()
        append = self.queue.append
        for coroutine in old_queue:
            self.current = coroutine
            # assert isinstance(coroutine, Coroutine)
            try:
                value = coroutine.run()
            except StopIteration:
                self._coroutine_exit(coroutine, False)
                continue
            except:
                self._coroutine_exit(coroutine, True)
                if coroutine.parent is None and not self.ignore_exception:
                    self._debug_output("%s raise uncaught exception", coroutine.cid())
                    raise
                else:
                    continue
            if value is None:
                # yield to other coroutines
                append(coroutine)
            elif isinstance(value, GeneratorType):
                sub = Coroutine(value, coroutine)
                append(sub)
                self.cid2coroutine[sub.cid()] = sub
            elif isinstance(value, SocketIO):
                self._do_coroutine_io(coroutine, value)
            elif isinstance(value, Sleep):
                self._do_coroutine_sleep(coroutine, value.seconds)
            elif isinstance(value, JoinAction):
                self._do_coroutine_join(coroutine, value)
            else:
                # this coroutine exit
                self._coroutine_exit(coroutine, False)
        self.current = None

    def _process_sleep_queue(self):
        now = time()
        from_start_ms = int(1000 * (now - self.start_time))
        millisecond_heap = self.millisecond_heap
        while millisecond_heap:
            # check recent till millisecond time
            till = heappop(millisecond_heap)
            # get all timeout tasks in this millisecond
            item_map = self.timer_slots_map.pop(till)
            if till > from_start_ms:
                # there are some tasks in this millisecond, so loop ends
                if item_map:
                    self.timer_slots_map[till] = item_map
                    heappush(self.millisecond_heap, till)
                    return min(1.0, 0.001 * (till - from_start_ms))
                else:
                    # no task, continue to next millisecond
                    continue
            # do time out tasks
            assert isinstance(item_map, dict)
            for timeout_item in item_map.itervalues():
                assert isinstance(timeout_item, _TimeoutItem)
                wait_type = timeout_item.wait_type
                if wait_type is WAIT_CANCELED:
                    continue
                assert timeout_item.till == till
                arg = timeout_item.arg
                if wait_type is WAIT_JOIN:
                    # join time out
                    waiting_cid, target_cid = arg
                    waiters = self.waiting_map[target_cid]
                    assert isinstance(waiters, dict)
                    self._debug_output("coroutine %s join %s time out", waiting_cid, target_cid)
                    del waiters[waiting_cid]
                    if not waiters:
                        del self.waiting_map[target_cid]
                    # wake up this waiter
                    waiter = self.cid2coroutine[waiting_cid]
                    # true: really timeout
                    waiter.reset_input(True)
                    self.queue.append(waiter)
                    self._debug_output("%s timeout on join", waiter)
                elif wait_type is WAIT_SOCKET:
                    # io time out
                    sock_fd = arg
                    self._debug_output("socket %s io timeout", sock_fd)
                    # skip if this sock_fd is not registered with any coroutine
                    if sock_fd not in self.sock_map:
                        continue
                    # un-register event watch
                    self.io_read_queue.discard(sock_fd)
                    self.io_write_queue.discard(sock_fd)
                    # find the owner coroutine of this sock_fd
                    coroutine, timeout_item = self.sock_map[sock_fd]
                    assert isinstance(coroutine, Coroutine)
                    # owner maybe already exited
                    if coroutine.cid() not in self.cid2coroutine:
                        continue
                    try:
                        raise SocketTimeoutError(ETIMEDOUT, "timeout")
                    except SocketTimeoutError:
                        # raise exception to this coroutine
                        coroutine.reset_input(None, exc_info())
                        self.queue.append(coroutine)
                        self._debug_output("%s timeout on socket", coroutine)

                else:
                    # sleep type, arg is sleeping coroutine. sleep is reached, so wake up this coroutine
                    assert wait_type is WAIT_SLEEP
                    assert isinstance(arg, Coroutine)
                    self.queue.append(arg)
                    self._debug_output("%s wake up from sleep", arg)
            del item_map
        return 0.0

    def _process_io(self, sleep_seconds):
        io_read_queue = self.io_read_queue
        io_write_queue = self.io_write_queue
        queue_append = self.queue.append
        if self.use_fake_select:
            rxlist, wxlist, exlist = fake_select(io_read_queue,
                                                 io_write_queue, [],
                                                 sleep_seconds)
        else:
            rxlist, wxlist, exlist = select(io_read_queue,
                                            io_write_queue, [],
                                            sleep_seconds)
        # collect coroutines waiting for these sockets
        io_read_queue -= set(rxlist)
        io_write_queue -= set(wxlist)
        if exlist:
            exset = set(exlist)
            io_read_queue -= exset
            io_write_queue -= exset

        # wake coroutines
        for sock_fd in chain(rxlist, wxlist, exlist):
            self._debug_output("socket %s become ready", sock_fd)
            coroutine, timeout_item = self.sock_map[sock_fd]
            cid = coroutine.cid()
            assert cid in self.cid2coroutine
            queue_append(coroutine)
            # try to cancel io timeout item
            assert timeout_item.wait_type is WAIT_SOCKET
            self.timer_slots_map[timeout_item.till].pop(timeout_item.id)

    def _run(self):
        if self.running:
            raise CoroutineError("already running")
        self.running = True
        self._debug_logs = []

        io_read_queue = self.io_read_queue
        io_write_queue = self.io_write_queue
        # start to run all coroutines until all exited
        while self.alive_coroutine_num > 0:
            self._process_running_queue()
            sleep_seconds = self._process_sleep_queue()
            if io_read_queue or io_write_queue:
                if self.queue:
                    # print "queue is not empty, io timeout set 0"
                    sleep_seconds = 0
                elif sleep_seconds > 1:
                    sleep_seconds = 1
                self._process_io(sleep_seconds)
            elif sleep_seconds > 0 and not self.queue and self.millisecond_heap:
                # sleep_seconds += 0.0003
                self._debug_output("try to sleep for %.6fs", sleep_seconds)
                sleep(sleep_seconds)
                self._debug_output("wake up at %.6f", time())
        # ended
        self.running = False
        assert not self.queue
        assert not self.io_read_queue
        assert not self.io_write_queue
        assert not self.millisecond_heap
        assert not self.timer_slots_map
        assert not self.waiting_map
        assert not self.cid2coroutine
        assert self.current is None
        self.sock_map.clear()

    @classmethod
    def add(cls, coroutine):
        return cls.get_instance()._add(coroutine)

    @classmethod
    def add_many(cls, coroutine_list):
        """
        add many coroutines to scheduler

        :param coroutine_list:  coroutine array
        :return: scheduler
        """
        for coroutine in coroutine_list:
            cls.get_instance()._add(coroutine)
        return cls.get_instance()

    @classmethod
    def run(cls):
        return cls.get_instance()._run()

    @classmethod
    def set_debug(cls, debug=True, collect_logs=False):
        cls.get_instance().debug = debug
        cls.get_instance().collect_debug_logs = collect_logs

    @classmethod
    def get_debug_logs(cls):
        return cls.get_instance()._debug_logs

    @classmethod
    def set_use_fake_select(cls, use_fake_select=True):
        cls.get_instance().use_fake_select = use_fake_select


def async_urlopen(sock, url, method="GET", headers=(), data=""):
    """
    async HTTP request

    :param sock:
    :param url:
    :param method:
    :param headers: (head, value) headers list
    :param data:
    :return response: (code, reason, headers, body)
    """
    pieces = [method, ' ', url, ' HTTP/1.1\r\n', ]
    for head, val in headers:
        pieces.extend((head, ':', val, '\r\n'))
    pieces.extend(('Content-Length:', str(len(data)), '\r\n'))
    pieces.append('Connection: keep-alive\r\n\r\n')
    pieces.append(data)
    req_bin = ''.join(pieces)
    while req_bin:
        yield SocketIO(sock.fileno(), read=False)
        sent = sock.send(req_bin)
        req_bin = req_bin[sent:]
    resp_bin = ""
    resp_len = -1
    code = 400
    reason = "bad request"
    while resp_len != len(resp_bin):
        yield SocketIO(sock.fileno(), read=True)
        data = sock.recv(32 << 10)
        if resp_len > 0:
            resp_bin += data
        else:
            resp_bin += data
            parts = resp_bin.split('\r\n\r\n', 1)
            if len(parts) != 2:
                continue
            head_bin, resp_bin = parts
            lines = head_bin.split('\r\n')
            status_line = lines[0]
            version, code, reason = status_line.split(' ', 2)
            code = int(code)
            # head_bin ends at the last header line, so keep every line after the status line
            headers = [line.split(':', 1) for line in lines[1:]]
            if method == 'HEAD':
                break
            resp_len = 0
            for head, val in headers:
                if head.lower() == 'content-length':
                    resp_len = int(val)
                    break
    yield (code, reason, headers, resp_bin)

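How timeouts are implemented

Timeouts serve three features: sleeping, IO timeouts, and JoinAction timeouts. The three are similar: each waits for a period of time and, once the deadline arrives, handles it in its own way. In other words, each needs to create a one-shot timer task. When the deadline arrives, a sleep task wakes the coroutine and puts it on the run queue; an IO task wakes the coroutine and raises a timeout exception in it; a JoinAction task wakes the waiting coroutine and passes it the timed-out result. The latter two differ in one respect: their timer tasks may be cancelled midway. If the IO event arrives in time, the timeout task must be cancelled. If the target coroutine exits in time, the JoinAction timeout task must be cancelled as well.
To simplify the implementation, timeout tasks are only accurate to the millisecond, so a deadline can be represented as an integer number of milliseconds.
First, the two main data structures: timer_slots_map and millisecond_heap.
As its name suggests, millisecond_heap is a min-heap of integer millisecond values. Each value is a deadline expressed as milliseconds elapsed since the scheduler was created (start_time), and it indicates that timeout tasks may exist at that millisecond. For example: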

+++++++++++++
|100|200|280|
+++++++++++++

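The heap holds three times, 100, 200, and 280, meaning that timeout tasks may be due around 100 ms, 200 ms, and 280 ms after the start (_add_timeout rounds each deadline to the nearest millisecond).
A heap gives fast access to the nearest time and fast insertion of new times. Because the structure itself is efficient (O(log n) push and pop) and Python's heapq has a C implementation, adding and removing entries stays cheap even when the heap is large.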
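
A minimal sketch of the heap behaviour described above, using the standard heapq module on a plain list. The values are the three sample times; the variable names nearest and popped are only for illustration.

```python
from heapq import heappop, heappush

millisecond_heap = []
# Insert in arbitrary order; heappush keeps the smallest value at index 0.
for till in (280, 100, 200):
    heappush(millisecond_heap, till)

nearest = millisecond_heap[0]  # peek at the nearest time without removing it

# Popping repeatedly yields the times in ascending order.
popped = [heappop(millisecond_heap) for _ in range(3)]
```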
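Next is timer_slots_map, a slightly more complex structure whose job is to store every timer task. Its first level is a dict that maps a millisecond time to the collection of timer tasks due at that time. That collection is itself a dict: each timer task is represented by a timeout_item, and the mapping is id(timeout_item) -> timeout_item.
Here is an example of timer_slots_map.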

{
    100 => { id1 => timeout_item1 , id2 => timeout_item2 },
    280 => { id3 => timeout_item3},
    200 => {}
}

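As shown above, three time points are present: 100 has two tasks, while 200 has none. This is normal and happens whenever timer tasks are cancelled.
Now for the timer operations that need to be implemented:

  1. Add a timer task
  2. Cancel a timer task
  3. Get the nearest timer task

1. Adding a timer task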

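This puts the timer task timeout_item into the queue. A timer task carries its type, its argument, and so on; here we only care about the time.
First compute the deadline, which gives an integer millisecond value till. Then check whether till is already in timer_slots_map. If it is, till must already be in millisecond_heap; otherwise it has to be added to millisecond_heap with heappush(), which maintains the heap structure automatically. Finally, insert timeout_item into the dict timer_slots_map[till].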

timer_slots_map[till][id(timeout_item)] = timeout_item
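
The add operation can be sketched in isolation as follows. This is a simplified version of _add_timeout, not a copy of it: TimeoutItem and add_timeout are illustrative names, and the elapsed time is passed in as an argument (an assumption made here so the example is deterministic) instead of being read from time().

```python
from heapq import heappush

timer_slots_map = {}   # till (ms) -> {id(item): item}
millisecond_heap = []  # min-heap of till values


class TimeoutItem(object):
    """Illustrative stand-in for the article's _TimeoutItem."""

    def __init__(self, till, arg):
        self.till = till
        self.arg = arg


def add_timeout(elapsed_seconds, seconds, arg):
    """Register a one-shot timer due `seconds` after `elapsed_seconds`."""
    # Adding 0.5 ms before truncating rounds to the nearest millisecond.
    till = int(1000 * (elapsed_seconds + seconds + 0.0005))
    item = TimeoutItem(till, arg)
    slot = timer_slots_map.get(till)
    if slot is None:
        # First task for this millisecond: create the slot and push the time.
        timer_slots_map[till] = {id(item): item}
        heappush(millisecond_heap, till)
    else:
        # The time is already in the heap; only the slot changes.
        slot[id(item)] = item
    return item


a = add_timeout(0.0, 0.1, "a")
b = add_timeout(0.05, 0.05, "b")  # same millisecond as a
c = add_timeout(0.0, 0.25, "c")
```

Tasks a and b share one slot, so the heap holds two entries for three tasks.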

2. Cancelling a timer task

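The input is a timeout_item.
First read the deadline till from the timer task, then remove timeout_item from the dict timer_slots_map[till]. Because the key is the id of timeout_item, deleting by that id is all that is needed. In practice this means the timeout_item being removed must be the same object that was used when the task was added.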

del timer_slots_map[till][id(timeout_item)]

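Note that adding a task may push till onto millisecond_heap, but cancelling one does not remove till from millisecond_heap. This is deliberate: a heap is essentially an array, and removing an element from the middle of an array is expensive. Leaving till in place does little harm, and because times are kept in whole milliseconds, the length of millisecond_heap is bounded by the number of distinct milliseconds that have entries. If times were stored as exact double-precision values instead, almost every task would get its own heap entry and millisecond_heap could grow far larger.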
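
A small sketch of this lazy cancellation, under the assumption that the two structures already hold two tasks. TimeoutItem and cancel_timeout are illustrative names; unlike the one-line del shown earlier, this version uses pop() with a default so that cancelling twice does not raise.

```python
class TimeoutItem(object):
    """Illustrative stand-in for the article's _TimeoutItem."""

    def __init__(self, till):
        self.till = till


item_a = TimeoutItem(100)
item_b = TimeoutItem(280)
timer_slots_map = {
    100: {id(item_a): item_a},
    280: {id(item_b): item_b},
}
millisecond_heap = [100, 280]


def cancel_timeout(item):
    """Drop the task from its slot; the heap is left untouched on purpose."""
    timer_slots_map[item.till].pop(id(item), None)


cancel_timeout(item_a)
# The slot for 100 is now empty, but 100 is still in millisecond_heap.
# It gets discarded later, when the nearest time is examined.
```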

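3. Getting the nearest timer task

millisecond_heap is a min-heap, so its first element is the nearest time. heappop() pops that time off millisecond_heap, and the time is then used to look up the corresponding timer tasks in timer_slots_map.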
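
The lookup can be sketched as a standalone function. This is a variation on _process_sleep_queue, not the same code: it peeks at the heap top instead of popping and pushing back, and it returns the due tasks instead of dispatching them. The name pop_due and its return shape are assumptions made for this example.

```python
from heapq import heappop


def pop_due(millisecond_heap, timer_slots_map, now_ms):
    """Return (due_items, wait_ms); wait_ms is None if nothing is pending."""
    due = []
    while millisecond_heap:
        till = millisecond_heap[0]  # nearest time, not yet removed
        slot = timer_slots_map.get(till, {})
        if till > now_ms:
            if slot:
                # Nearest real task lies in the future: report how long to wait.
                return due, till - now_ms
            # Every task in this slot was cancelled: discard the stale time.
            heappop(millisecond_heap)
            timer_slots_map.pop(till, None)
            continue
        # The slot is due: remove it and collect its tasks.
        heappop(millisecond_heap)
        timer_slots_map.pop(till, None)
        due.extend(slot.values())
    return due, None


heap = [100, 200, 280]
slots = {100: {1: "a", 2: "b"}, 200: {}, 280: {3: "c"}}
due, wait_ms = pop_due(heap, slots, 150)
```

Here the empty slot at 200 is the leftover of a cancelled task and is cleaned up as a side effect of the lookup.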

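Summary

Together, a heap and a dict implement timer tasks efficiently and concisely.
When there is a lot of IO, the number of timer tasks can grow quickly. To keep millisecond_heap short, the timeout time can be rounded to a coarser granularity, such as 10 ms or even 100 ms.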
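
One possible way to apply a coarser granularity is sketched below. The function quantize_till and its granularity_ms parameter are hypothetical, and rounding up (so that a task never fires before its deadline) is a choice made for this example; the article does not say which direction to round.

```python
def quantize_till(elapsed_seconds, seconds, granularity_ms=10):
    """Return the deadline in ms, rounded up to a multiple of granularity_ms."""
    ms = int(1000 * (elapsed_seconds + seconds + 0.0005))
    # Ceiling division using integers only: -(-a // b) == ceil(a / b).
    return -(-ms // granularity_ms) * granularity_ms


t1 = quantize_till(0.0, 0.123)       # 123 ms falls into the 130 ms slot
t2 = quantize_till(0.0, 0.120)       # already on a 10 ms boundary
t3 = quantize_till(0.0, 0.127, 100)  # 100 ms granularity
```

With 10 ms slots, deadlines spread over one second occupy at most about 100 heap entries instead of about 1000, at the cost of timers firing up to one slot late.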

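How JoinAction is implemented

The main data structure is waiting_map. It is a dict whose keys are coroutine ids and whose values are the set of coroutines waiting on that coroutine. Each value is itself a dict whose keys are the waiters' coroutine ids and whose values are their timer tasks.
For example: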

{
  c1 => { waiter1 => timeout_item1, waiter2 => timeout_item2 },
  c2 => { waiter3 => timeout_item3, waiter4 => timeout_item4 }
}

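waiter1 and waiter2 are both waiting on coroutine c1, each with its own timeout task.
When coroutine c1 exits, the scheduler walks the waiters of c1, wakes every waiting coroutine, and removes its timeout task.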
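
The wake-up step can be sketched on plain data. This is a simplification of the waiter handling in _coroutine_exit: the function name wake_waiters is made up, coroutines are represented by string ids, and each timer task is reduced to a (till, item_id) tuple, all of which are assumptions for this example.

```python
from collections import deque


def wake_waiters(waiting_map, timer_slots_map, run_queue, target_cid):
    """Wake every coroutine joined on target_cid; return how many were woken."""
    waiters = waiting_map.pop(target_cid, None) or {}
    for waiter_cid, (till, item_id) in waiters.items():
        # False is the join result: the wait did not time out.
        run_queue.append((waiter_cid, False))
        # Cancel the waiter's timeout task; the heap entry is left in place.
        timer_slots_map[till].pop(item_id, None)
    return len(waiters)


waiting_map = {
    "c1": {"waiter1": (100, 1), "waiter2": (200, 2)},
    "c2": {"waiter3": (100, 3)},
}
timer_slots_map = {100: {1: "item1", 3: "item3"}, 200: {2: "item2"}}
run_queue = deque()
woken = wake_waiters(waiting_map, timer_slots_map, run_queue, "c1")
```

Waiters of c2 are untouched, and their timeout tasks stay in timer_slots_map.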
