[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification

Contents

  • Introduction
  • Method
    • Speculative Inference
      • Collective Boost-Tuning
      • Learning-based Speculative Scheduler
    • Token Tree Verifier
      • Tree Attention
      • Verification
      • Optimizations
  • Evaluation
  • References

Introduction

  • LLMs 的高内存和算力需求使得构建响应迅速且低成本 (quickly and cheaply) 的 LLMs 推理系统是十分困难的,例如 GPT-3 有 175B 的参数,如果用 FP32 存储需要耗费至少 16 张 40GB A100 GPUs,并且由于推理是自回归的,还需要数秒才能处理完一次推理需求
  • 现有的 LLM 推理系统一般是采用 incremental decoding,即先用一个迭代步计算完所有 prompt tokens (i.e., 输入文本) 的激活值,然后根据 input prompt 和所有之前生成的 tokens,在每个 step 迭代地解码出一个新的 token. 由于生成一个新 token 就需要一个 step,这种方法推理速度受限,并且 GPU 使用率也不高
    [Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第1张图片[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第2张图片
  • 此外,由于生成新 token 时需要使用前序 tokens 的 keys 和 values,为了避免在每个 step 都重新计算前序 tokens 的 keys 和 values,LLM 推理系统通常会使用 K-V cache,将已经计算完成的 keys 和 values 存起来,但对于长序列生成任务 (e.g., GPT-4 supports up to 32K tokens in a request),这会产生巨大的内存开销,使得系统无法并行处理大量的推理需求
  • 为了解决上述问题,作者构建了 LLM serving system – SpecInfer,利用 speculative inferencetoken tree verification 来改善 LLM 推理的端到端时延以及 computational efficiency (use an LLM as a token tree verifier instead of an incremental decoder). 具体来说,SpecInfer 集成了不同的 collectively boost-tuned small language models 来预测 LLM 的输出,多个 SSMs 的预测结果被构建为 token tree 的形式,token tree 上的每个结点都代表一个 small language model 预测的候选 token 序列。token tree 中所包含的所有候选 token 序列的正确性由 LLM (i.e., token tree verifier) 进行验证,并且作者采用了 tree-based parallel decoding,使得所有序列的验证过程是并行的。如果 token tree 中有候选 token 序列和 LLM 的输出一致,就可以在一个 step 内生成多个 tokens
  • 作者在 two LLM families (i.e., LLaMA and OPT) 和 5 个 prompt 数据集上对 SpecInfer 进行了评估,实验表明 SpecInfer 最多能使得 LLM decoding steps 数减少 4.4 × \times × (3.7 × \times × on average),并且能最高降低端到端推理时延 2.8 × \times ×
    [Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第3张图片

Method

[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第4张图片

Speculative Inference

[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第5张图片

  • speculator 中包含的 SSMs (small speculative models) 可以是 (1) Fine-tuned SSM. 与 LLM 在同一数据集上预训练,但模型更小;(2) Quantized/Distillated/Pruned LLM;(3) Knowledge Retriever/User-defined function. Predict future tokens based on heuristics and/or retrieval-augmented documents.

Collective Boost-Tuning

[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第6张图片

  • SpecInfer 首先对 SSMs 进行无监督微调来使得它们的输出与 LLM 的输出对齐 (we perform collective boost-tuning offline)。具体来说,SpecInfer 使用 OpenWebText corpus 作为数据集,将数据集转化为一系列的 prompt 样本,然后用 LLM 为每个 prompt 生成一个 token 序列。SpecInfer 首先对第一个 SSM 进行微调使得其输出尽可能与 LLM 一致,然后过滤掉所有 SSM 与 LLM 产生相同 token 序列的 prompt 样本,用剩余 prompt 样本去微调下一个 SSM. 不断重复上述过程直至微调完所有 SSMs,SpecInfer 就得到了一组在训练集上的 aggregated output 与 LLM 的输出具有高一致性,且具备多样性的 SSMs
  • 所有 SSMs 的推理时延都差不多,而 SpecInfer 在不同 GPUs 上并行运行所有 SSMs,因此相比运行单一 SSM,时延也不会增加。同时,虽然使用多个 SSMs 会增加 GPUs 上的内存开销,但 SpecInfer 通过使用比 LLM 小 40-100 × \times × 的 SSMs 就能够带来显著的性能提升,且单个 SSM 只会增加 1-2% 的内存开销

Learning-based Speculative Scheduler

  • 每个 decoding step 的 token 生成难度可能是不一样的,并且使用很多 SSMs 也会导致 large token tree,会增加 verification 的内存和算力开销,而 learning-based speculative scheduler 就是在每个 step 选择使用哪些 SSMs 进行推理以及这些 SSMs 的 speculative configurations 来动态地平衡 SSMs 的生成质量和生成速度

  • matching length predictor. SpecInfer 使用 3 层 MLP (hidden size 为 64) 去预测每个 SSM 在不同 beam search 设置 (beam width b ∈ [ 1 , 2 , 4 ] b \in [1, 2, 4] b[1,2,4], beam depth d ∈ [ 1 , 2 , 4 , 8 , 16 ] d \in [1, 2, 4, 8, 16] d[1,2,4,8,16]) 下的 expected matching length,输入为 LLM 最后的 hidden layer 输出的 feature h h h,输出为 f ( ⋅ ∣ h ) ∈ R 15 f(\cdot| h)\in\R^{15} f(h)R15. 并且 predictor 也是提前在线下训练好的 (We train the predictor on 200K samples over the OpenWebText corpus.)
  • cost model. 为了选择具有更高 matching length per unit time 的推理设置,作者定义了如下 cost function:
    在这里插入图片描述在这里插入图片描述其中, L verify ( b , d ) , L speculate ( b , d ) L_{\text{verify}}(b,d),L_{\text{speculate}}(b,d) Lverify(b,d),Lspeculate(b,d) 分别为 verifier 和 speculator 在给定 beam search 配置下的估计推理时延,可以通过 profiling the SpecInfer runtime system 得到

Token Tree Verifier

[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第7张图片

Tree Attention

  • LLMs 一般使用注意力机制处理 token 序列
    在这里插入图片描述在这里插入图片描述
  • Tree attention. 给定 token tree N \mathcal N N,我们的目标是用 LLM 高效地计算出 token tree 中每个结点 u ∈ N u\in\mathcal N uN 的 next token O ( u ) \mathcal O(u) O(u) 用于之后的 verification (具体方法在 “Optimizations” 一节中展开)

Verification

  • 下面的 VERIFY 函数描述了 Specinfer 的验证过程。对于结点 u ∈ N u\in\mathcal N uN,如果 u u u 存在一个子结点 v v v (i.e., p v = u p_v = u pv=u), v v v 对应的 token 和 tree attention 中 LLM 计算的 next token O ( u ) \mathcal O(u) O(u) 一致 (i.e., t v = O ( u ) tv = \mathcal O(u) tv=O(u)),则结点 u u u 验证通过,可以令 u = v u=v u=v,继续验证子结点,否则验证不通过,将 O ( u ) \mathcal O(u) O(u) 作为 u u u 的下一个结点并终止验证. 从根结点开始不断重复上述过程即可得到输出的 token 序列

[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第8张图片

Optimizations

[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第9张图片

  • Sequence-based Decoding. LLM 推理系统一般使用 KV cache 来避免重新计算前序 tokens 的 keys 和 values,而验证整个 token tree 的多个 token 序列带来的一个问题就是不同 token 序列可能包含冲突的 KV caches,例如上图中的 ( t 2 , t 3 , t 4 , t 5 ) (t_2,t_3,t_4,t_5) (t2,t3,t4,t5) ( t 2 , t 3 , t 8 , t 9 ) (t_2,t_3,t_8,t_9) (t2,t3,t8,t9) 的 3、4 个 token 不同,要存的 key/value 值也不同。一种简单的解决方式就是对 token tree 上的每个 token 序列都使用不同的 KV cache,但这样会 (1) 需要多个 KV caches;(2) token tree 上的不同序列可能会有相同的 prefix,对于 prefix 而言会产生冗余计算
  • Token-based DecodingDepth-first search to update key-value cache. SpecInfer 采用了 depth-first search 的顺序遍历 token tree,从而可以使用单个 KV cache 完成验证,并且不必重复计算。Token-based Decoding 即为按照 DFS 顺序依次计算每个 token 的输出
  • Tree-based parallel decoding. 在 Token-based Decoding 中,由于每个 GPU kernel 只计算一个 token 的输出,因此会导致 high GPU kernel launch overhead. 为此,SpecInfer 采用了 tree-based parallel decoding,每个 token 序列上的所有 token 的输出都可以并行计算,并且不同 token 序列之间还可以共享 KV cache,避免冗余计算 (论文里的这句话看不太懂:“To batch attention computation, SpecInfer uses the key-value cache of the kernel’s last token (i.e., t 5 t_5 t5 for the first kernel), which results in attention scores that violate the casual dependency. SpecInfer then fixes the attention scores for these pairs”) (还有一个问题就是论文中提到 “tree-based parallel decoding algorithm to simultaneously verify all tokens in a speculated token tree in a single LLM decoding step.”,不太明白为什么可以在一个 step 内验证完)

Evaluation

  • End-to-end Performance.
    [Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第10张图片
  • Collective Boost-Tuning.
    [Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第11张图片[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第12张图片[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第13张图片[Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第14张图片
  • Learning-based Speculative Scheduler. Using the predictor can achieve similar LLM runs while reducing the SSM runs significantly due to dynamic speculation length.
    [Arxiv 2023] SpecInfer:Accelerating LLM Serving with Speculative Inference + Token Tree Verification_第15张图片

References

  • Miao, Xupeng, et al. “SpecInfer: Accelerating Generative LLM Serving with Speculative Inference and Token Tree Verification.” arXiv preprint arXiv:2305.09781 (2023).
  • code: https://github.com/flexflow/FlexFlow/tree/inference

你可能感兴趣的:(模型部署,Arxiv,2023)