vLLM (3) - Sequence & SequenceGroup

Series Articles

vLLM (1) - Qwen2 Inference & Deployment
vLLM (2) - Architecture Overview
vLLM (3) - Sequence & SequenceGroup


Table of Contents

  • Series Articles
  • Preface
  • I. SequenceStage & SequenceStatus
    • 1. SequenceStage
    • 2. SequenceStatus
  • II. SequenceData & Sequence
    • 1. SequenceData
    • 2. Sequence
  • III. SequenceGroup & SequenceGroupMetadata
    • 1. SequenceGroup
    • 2. SequenceGroupMetadata
  • Summary


Preface

A request usually corresponds to one or more sequences. In vLLM, a sequence is represented by a Sequence, and all sequences belonging to the same request form a sequence group, represented by a SequenceGroup. These classes are not hard to understand in themselves, but they involve a sequence's status (WAITING or FINISHED, say), the stage it is in (prefill or decode), its logical blocks, and so on. Understanding them matters for the later analysis of vLLM's scheduling and block allocation.

Also, this series has not yet given a formal or in-depth introduction to prefill & decode (the process is shown in the figure below) or the KV cache. If you are unfamiliar with them or still confused, I strongly recommend reading the article GLM-4 (6) KV Cache / Prefill & Decode, since both are important prerequisites for understanding vLLM.
[Figure 1: the prefill & decode process]


I. SequenceStage & SequenceStatus

1. SequenceStage

SequenceStage represents the stage a sequence is in. In general, while the prompt is being computed the sequence is in the prefill stage, PREFILL; once the model is generating, i.e., sampling tokens, it is in the decode stage, DECODE:

import enum

class SequenceStage(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()

2. SequenceStatus

SequenceStatus represents the status of a sequence:

  1. WAITING: waiting; prefill has not started yet. Initially, vLLM puts as many incoming requests as possible into the waiting queue;
  2. RUNNING: inference has started and has not finished yet;
  3. SWAPPED: swapped out. KV cache blocks are allocated to requests dynamically, so as more tokens are generated, GPU memory pressure rises (fewer free blocks remain). When vLLM predicts that memory will not suffice at the next step, the blocks of low-priority sequences are preempted and their data may be moved to the CPU, putting those sequences in the SWAPPED state. Moving data from GPU to CPU is called swap out; the reverse is swap in;
  4. FINISHED: decoding is complete. Depending on the cause, this splits into four cases: a) FINISHED_STOPPED: generation stopped normally; b) FINISHED_LENGTH_CAPPED: generation was cut off at the maximum length; c) FINISHED_ABORTED: the request was actively aborted; d) FINISHED_IGNORED: the prompt already exceeds the maximum length, so the request is ignored outright. (A quick check of the corresponding finish reasons follows the code below.)
import enum
from typing import Union

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
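
As a quick sanity check (my own snippet, not from the vLLM source), the loop below maps each terminal status to the finish reason reported through the API; note that FINISHED_LENGTH_CAPPED and FINISHED_IGNORED both surface as "length":

for status in SequenceStatus:
    if SequenceStatus.is_finished(status):
        print(status.name, "->", SequenceStatus.get_finished_reason(status))
# FINISHED_STOPPED -> stop
# FINISHED_LENGTH_CAPPED -> length
# FINISHED_ABORTED -> abort
# FINISHED_IGNORED -> length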

II. SequenceData & Sequence

1. SequenceData

SequenceData is a class dedicated to managing a sequence's data. As the code below shows, it records:

  • self.prompt_token_ids and self.output_token_ids: the prompt tokens and the generated tokens;
  • self._num_computed_tokens: the number of prefill tokens computed so far. "Prefill" here is not limited to the prompt: as the comment in self.get_num_uncomputed_tokens() explains, the recompute case covers prompt + output;
  • self._stage: the sequence starts in the prefill stage, SequenceStage.PREFILL; once prefill is done, i.e., self.get_num_uncomputed_tokens() == 0, it should move on to decoding and enter the SequenceStage.DECODE stage, as shown in self.update_num_computed_tokens(). (A short walk-through follows the code.)
from typing import List, Optional, Tuple

class SequenceData:
    """Data associated with a sequence."""

    def __init__(
        self,
        prompt_token_ids: List[int],    # prompt tokens
        output_token_ids: Optional[List[int]] = None,  # generated tokens
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0    # cumulative log probability of the output
        self._num_computed_tokens = 0    # number of prefill tokens already computed
        self._stage: SequenceStage = SequenceStage.PREFILL   # a sequence starts in the prefill stage

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE
            
    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_len(self) -> int:
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        # take the first num_tokens tokens; the prompt part and the output
        # part together make up the return value
        prompt_length = len(self.prompt_token_ids)
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    # ...
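
The stage transition is easy to verify. Below is a minimal walk-through (my own sketch; it relies on get_num_computed_tokens(), elided from the excerpt above, which simply returns self._num_computed_tokens):

data = SequenceData(prompt_token_ids=[11, 12, 13, 14])
assert data._stage == SequenceStage.PREFILL   # starts in prefill
data.update_num_computed_tokens(4)            # the whole prompt is now computed
assert data._stage == SequenceStage.DECODE    # ready to generate
data.reset_state_for_recompute()              # e.g., the sequence was preempted
assert data._stage == SequenceStage.PREFILL   # back to prefill over prompt + output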
        

2. Sequence

Sequence is a class that stores a sequence's data, status, and block information. Walking through the code:

  • data storage: self.data = SequenceData(self.prompt_token_ids);
  • status: self.status = SequenceStatus.WAITING; if the sequence later becomes FINISHED, the stop reason self.stop_reason also comes into play;
  • block information: self.logical_token_blocks holds the logical blocks associated with this sequence.

The previous article already covered some block concepts; here is a quick recap. A block is a unit of GPU/CPU memory management: vLLM divides the available GPU and CPU memory into blocks of a given size block_size, each corresponding to a fixed amount of memory. block_size defaults to 16, which means one block can store the KV cache of 16 tokens. The code below also uses the notion of a slot, which is one position within a block, i.e., room for the KV cache of a single token.
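
As a concrete example (my own numbers, not from the source): with the default block_size of 16, a 40-token prompt occupies 3 logical blocks, and the last block still has 8 empty slots for generated tokens.

import math

block_size = 16
prompt_len = 40
num_blocks = math.ceil(prompt_len / block_size)     # 3 logical blocks
empty_slots = num_blocks * block_size - prompt_len  # 8 free slots in the last block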

Sequence has several important methods; I have excerpted parts of the code below and added comments. The main operations are:

  • self._append_tokens_to_blocks(): when new tokens arrive, append them to the logical blocks. If the current block still has empty slots, the tokens go into that block; if the block is already full, a new logical block must be allocated first via self._append_logical_block() (a small demo follows the code);
  • self.fork(): clone a new sequence from this one.
import copy
from typing import Dict, List, Optional, Union
# LLMInputs, LogicalTokenBlock, Logprob, LoRARequest and SampleLogprobs
# are vLLM's own types, omitted from this excerpt.

class Sequence:
    """Stores the data, status, and block information of a sequence."""

    def __init__(
        self,
        seq_id: int,       # id of this sequence; every sequence gets one
        inputs: LLMInputs,  # the input: (prompt_token_ids, prompt, multi_modal_data)
        block_size: int,   # block size, 16 by default
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(self.prompt_token_ids)
        self.status = SequenceStatus.WAITING           # sequence status, WAITING at this point
        self.stop_reason: Union[int, str, None] = None # reason generation stopped

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    def hash_of_block(self, logical_idx: int) -> int:
        """Compute the hash of a block."""
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        # get_prefix_token_ids takes the first num_tokens tokens of this
        # sequence and returns the prompt part and the output part separately
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Given a logical block index, return the total number of tokens
        up to and including that block."""
        return logical_idx * self.block_size + self.block_size

    def _append_logical_block(self) -> None:
        """Append a new logical block.
        block_number: the sequential index of this logical block;
        block_size: 16 by default.
        """
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            # keep filling until every token has been placed

            # initially there are no logical blocks, so append one directly
            if not self.logical_token_blocks:
                self._append_logical_block()

            # always work on the last block; if it is already full,
            # append a new one first
            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            # the maximum number of tokens this block can still hold
            num_empty_slots = last_block.get_num_empty_slots()
            # fill in up to that many tokens
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            # advance the cursor, i.e., the index of the next token to place
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        assert token_id in logprobs
        # append the token to the logical blocks
        self._append_tokens_to_blocks([token_id])
        # record the logprobs of the new token
        self.output_logprobs.append(logprobs)
        # keep the data in SequenceData in sync
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def fork(self, new_seq_id: int) -> "Sequence":
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq
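
To see the block bookkeeping in action, here is a small demo (my own sketch; it assumes LLMInputs is the dict described in the constructor comment above and relies on the prompt_token_ids property elided from the excerpt; the tiny block_size is chosen just for illustration):

seq = Sequence(
    seq_id=0,
    inputs={"prompt": "hello", "prompt_token_ids": [1, 2, 3, 4, 5],
            "multi_modal_data": None},
    block_size=4,   # tiny on purpose; the default is 16
)
print(len(seq.logical_token_blocks))                       # 2 blocks: [1,2,3,4] and [5]
print(seq.logical_token_blocks[-1].get_num_empty_slots())  # 3 empty slots
new_seq = seq.fork(new_seq_id=1)                           # deep copy with a new id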

III. SequenceGroup & SequenceGroupMetadata

1. SequenceGroup

SequenceGroup is a class representing the group of sequences generated from the same prompt: one prompt may be sampled several times, and each sample corresponds to one sequence. Its main parameters include the request_id, all the sequences seqs, the sampling parameters sampling_params, and the arrival time arrival_time. Let's look at a few of its methods:

  • self.prompt: all sequences share the same prompt;
  • self.is_prefill(): every sequence in the group is in the same stage, so it suffices to check whether any one of them is in prefill or decode;
  • self.get_max_num_running_seqs(): the maximum number of sequences that may run in parallel during the remaining lifetime of this request, which reflects the maximum memory the group may consume in the future (see the toy illustration after the code).
class SequenceGroup:
    """A group of sequences that are generated from the same prompt."""

    def __init__(
        self,
        request_id: str,       # request id
        seqs: List[Sequence],  # all sequences in the group
        arrival_time: float,   # arrival time of the request
        sampling_params: Optional[SamplingParams] = None,  # sampling parameters
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,   # for embedding models
        pooling_params: Optional[PoolingParams] = None,    # for embedding models
        encoder_seq: Optional[Sequence] = None,  # the single encoder sequence; only used with encoder/decoder models
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params

        # request-level metrics, mainly timestamps of the individual stages
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        # state of the sequence group
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.encoder_seq = encoder_seq

    @property
    def prompt(self) -> Optional[str]:
        # the group's prompt, i.e., the prompt of any one of its sequences
        return next(iter(self.seqs_dict.values())).prompt
        
    def is_prefill(self) -> bool:
        # every sequence should be in the same stage
        return self.get_seqs()[0].is_prefill()

    def get_max_num_running_seqs(self) -> int:
        """对于此请求,在余下的生命周期中并行序列的的最大数量"""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # beam search总是`best_of`个
            return self.sampling_params.best_of
        else:
            if (self.sampling_params and self.sampling_params.best_of > self.num_seqs()):
                # prompt阶段,只有1个;generation阶段 `best_of`个
                return self.sampling_params.best_of
            # 采样阶段,返回实际还未完成的数列数量
            return self.num_unfinished_seqs()
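
A toy illustration of the branch logic above (plain Python with made-up numbers, not vLLM code):

best_of = 4        # hypothetical sampling_params.best_of, beam search disabled
num_seqs = 1       # prompt stage: the group still holds a single sequence
num_unfinished = 1
if best_of > num_seqs:
    max_running = best_of         # 4: the group will fork into best_of sequences
else:
    max_running = num_unfinished  # generation stage: what is actually left running
print(max_running)                # -> 4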

2. SequenceGroupMetadata

SequenceGroupMetadata is the metadata of a sequence group, used to create AttentionMetadata. It has many parameters, which I have annotated in the code. Note that in the prompt stage the whole prompt is processed, so self._token_chunk_size = list(seq_data.values())[0].get_len(); in the generation stage tokens come out one at a time, so the chunk size is 1. (A quick check follows the code.)

class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`."""

    def __init__(
        self,
        request_id: str,   # request id
        is_prompt: bool,   # whether the request is in the prompt stage
        seq_data: Dict[int, SequenceData],   # sequence data, keyed by seq id
        sampling_params: SamplingParams,     # sampling parameters

        # Seq id -> list of physical block numbers: the physical blocks of each sequence
        block_tables: Dict[int, List[int]], 
        do_sample: bool = True,              # whether sampling is performed
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,    # with chunking, the number of tokens each sequence processes
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,  # used by prefix caching: blocks that are already computed
        state: Optional[SequenceGroupState] = None,       # state of the SequenceGroup
        multi_modal_data: Optional["MultiModalData"] = None,  # multi-modal data
        encoder_seq_data: Optional[SequenceData] = None,  # for encoder/decoder models, the encoder sequence data
        cross_block_table: Optional[List[int]] = None,    # with an encoder there is cross-attention, which has its own block table
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
        self.do_sample = do_sample

        # speculative decoding; ignored for now
        self.num_speculative_tokens = None

        if self._token_chunk_size is None:
            if is_prompt:
                # prompt stage: process the full length of the prompt
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                # decode stage: process one token at a time
                self._token_chunk_size = 1
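
A quick check of that default (my own sketch, reusing SequenceData from section II):

data = SequenceData([1, 2, 3, 4, 5, 6, 7])   # a 7-token prompt, no output yet
seq_data = {0: data}
print(list(seq_data.values())[0].get_len())  # prompt stage chunk size -> 7
# in the decode stage the chunk size is simply 1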

Summary

This article introduced Sequence, SequenceGroup, and related classes, covering a sequence's status markers, the scheduling stage it is in, how blocks are allocated for new tokens that need computing, and more. Starting from the next article we will see how vLLM copes with large numbers of requests. Stay tuned.
