放在最前:
Relax 的关键创新
深度学习模型(比如 ChatGPT这种大模型)在运行时经常遇到“输入尺寸不固定”的情况。比如你问它一个问题,这次输入是10个字,下次可能是100个字。传统编译器处理这种“变来变去”的尺寸很笨——要么只能按固定尺寸优化(导致变尺寸时性能暴跌),要么每次都要重新编译(慢到没法用)。
Relax 的创新:
符号形状:让编译器学会“代数” Relax 允许编译器用“符号变量”(比如
n
)表示未知的尺寸,就像代数里的未知数。比如告诉编译器:“这个张量的形状是(n, 4)
,另一个是(n+1, 4)
”。这样编译器就能理解它们的尺寸关系,像做数学题一样优化内存和计算流程,而不是两眼一抹黑。跨层级优化:把大任务拆小,再拼起来 传统编译器优化像“一刀切”——要么全用高级抽象(但性能差),要么全用底层代码(但改不动)。Relax
允许同时用高级和低级代码,比如把一部分计算合并成高效内核(底层),另一部分保持灵活(高层)。就像修车时既能用现成零件,也能自己造零件,组合出最优方案。
接下来讲我们的故事。
在接触深度学习编译器之前,很多小伙伴会问:“为什么我们不能直接用 GCC、LLVM 这类成熟的通用编译器?”
因此,深度学习编译器(ML Compiler) 兴起:如 TVM、XLA、MLIR、IREE 等。它们往往具备以下能力:
现阶段,大模型(如 GPT-4、Llama2、CodeLlama 等)在对话、文本生成、AI 助手等方面迅猛发展。这类模型动辄数十亿到上千亿参数,且在推理时需要处理可变长度的上下文,采用Key-Value Cache等机制。这就带来了“动态形状(Dynamic Shape)”需求:
unique
, non_zero
之类);这些特性使得编译器如果不能很好地处理动态形状,就会大幅退化:
基于此,业界与学术界逐渐认识到:机器学习编译器需要更精细的动态形状感知能力,不再简单地把形状标记为 any
或 -1
,而应该能够在 IR 中把形状的符号维度与算术关系“显式”地表示出来,并支持相应的运算、优化、内存复用调度等。
Relax 就是为了解决上述问题而提出的一种编译 IR(中间表示),它来自 TVM 项目,是继 Relay IR 之后专门为“动态形状–aware”而设计的下一代 IR。
一等公民的符号形状(First-class symbolic shape)
(n, 4)
, (n+1, m)
, (n//2, 256)
等);在同一个 IR 中同时表示高层图级与底层算子级
可组合的动态形状优化
这些特点使得 Relax 能够比前代 IR(如 Relay)在大模型和复杂动态场景中拥有更灵活、更强劲的性能表现。
-1
或 symbolic name)。unique
),则插入 match_cast
断言或动态检查。对编译器而言,这些阶段可以是一系列Pass,有些 Pass 可能会重复数次,或者先做融合再做部分形状推断等,现实中可能比上图更复杂。不过,对初学者只需把握此“自上而下、分阶段优化”的大图即可。
本节我们用更多示例与短代码片段,帮助读者理解 Relax 的两大核心抽象:第一类符号形状(symbolic shape)与跨层 IR(cross-level IR)。学习编译器 IR 可能会略显抽象,但理解这部分对掌握后续的优化机制很关键。
在传统编译器或 ONNX、Relay 等前代系统中,遇到维度不确定时,常以 -1
或 any
表示。但这样会使编译期几乎无法进一步推断,导致大量潜在优化失败。Relax 的做法是:
Tensor((n, 4), "float32")
;n
是一个符号变量,代表“此维度在编译期不知道具体值,但它是所有后续算子里相同的那个 n”;n
,编译器就能判断它们形状上的关联,从而做正确的融合、内存复用等。例如,一个函数签名可以是:
def subfunc(x: Tensor((n, 4), "f32")) -> Tensor((n*4,), "f32"):
# x 的形状是 (n,4),编译器知道 n 是符号维度
...
# 返回一个形状 (n*4,) 的张量
return y
这里 (n*4,)
也是一个符号表达式,表示 1D 张量,长度为 n*4。
(n*4,)
。match_cast
与类型断言并不是所有维度都能在编译时精确得知表达式。某些算子(比如 unique
)的输出大小只有在运行时才知道。这时为了继续后续的形状推断,可以写类似:
# 伪代码示例
def unique_and_exp(x: Tensor((n,), "f32")) -> Tensor((m,), "f32"):
lv0 = unique(x) # output shape 不确定
lv1 = match_cast(lv0, Tensor((m,), "f32"))
# match_cast 表示把 lv0 的形状“匹配”为 (m,)
# 如果运行时发现 lv0 并非 1D,或与 m 不一致,就会报错
return exp(lv1)
match_cast
告诉编译器:“我期望 lv0 最终是一维张量,并把其长度符号标记为 m。” 在编译后,执行时如果实际大小不对,就会触发动态断言错误。一旦通过,则编译器就能用 m
去表示后续算子的形状。
在深度学习中,“算子级”常常是“CPU/GPU Kernel 代码、卷积/矩阵乘法内核、元素级 for-loop 程序”等;而“图级”则是算子之间的数据流。
call_tir
:调用自定义或自动生成的 TensorIR 函数;call_dps_library
:调用外部已经写好的库函数,采用 Destination-Passing Style(DPS)给定输出 buffer。也就是说,Relax 在图级的函数中,可以直接把若干节点替换成一次 “call_tir(…)”,对应一个底层循环级实现。后续若还想做跨算子融合或调度变更,可以再做局部更新,而不必重新从头编译。
由于可以在图级 IR 与算子级 IR“共存”,所以 Relax 可以分多阶段地把一部分图节点转换为 call_tir,保留另一部分依旧是图节点:
本节我们在介绍完 Relax 的核心抽象后,来看两项在大模型或变形模型里尤其重要的编译优化:算子融合和内存规划。这两者如果结合“动态形状意识(symbolic shape)”来做,往往能省下大量计算与显存资源。
深度学习中,很多小算子如果单独执行,需要反复把中间结果写回到显存再读出来,导致大量数据搬运开销。而把多层 element-wise(如 ReLU、Add、LayerNorm 的局部操作)或者某些常见操作(如 MatMul + BiasAdd + ReLU)合并为一个 kernel,可以大幅减少中间数据的 IO,提升 GPU 利用率。
要做融合,需要确认被融合的算子在形状上可匹配。如果二者在编译期就是完全静态如 (128,256) -> (128,256)
,这很简单。但在动态形状中,要判断 (n,256)
和 (n,256)
是不是同一个 n 还是两个无关符号,就需要依赖 Relax 对符号的统一管理。如果编译器识别它们共享同一个 n
,就可安全融合。否则只能分开。
比如:
# 假设前面已有 lv0: Tensor((n,256), "f32")
lv1 = relu(lv0) # shape (n,256)
lv2 = add(lv1, 1.0) # shape (n,256)
如果这两个算子是都在 IR 中定义为 call_tir(“relu_fn”)、call_tir(“add_fn”),并且编译器看到二者形状都是 (n,256)
,则可以把二者合并为一个新的 fused kernel:call_tir(“relu_add_fused”)。
融合后,Relax 会生成一个新的 TensorIR 函数,如 fused_relu_add
. 其循环伪代码可能长这样:
@tensorir_function
def fused_relu_add(X: Buffer((n,256), "f32"), Out: Buffer((n,256), "f32")):
for i, j in grid(n, 256):
# ReLU
tmp = max(X[i, j], 0.0)
# Add(1.0)
Out[i, j] = tmp + 1.0
随后,高层 IR 会把对 relu(…) 和 add(…) 的两个调用替换成对 fused_relu_add
的一次性调用,大大减少内核启动和中间存储的消耗。对大模型而言,类似的融合技巧能显著提高吞吐率。
在大模型推理或训练中,往往有很多中间张量的大小都取决于符号维度(如 n*256
、(n+1)*256
等)。如果编译器能够在图级知道某些张量“不会同时被用到”,并且它们的形状大小在运行时可确认相同,就能将它们复用同一块显存 buffer。
对静态形状,如 (128,256)
与 (128,256)
,很简单。对动态形状,如 (n,256)
与 (n,256)
:
n
是同一个符号,就表示大小相等;(n+1, 4)
,另一个是 (4n+4,)
,编译器可以做算术简化,判断其是否总相等;在大语言模型里,KVCache 占用大量显存,并且在对话生成时会不断扩充 shape。如果编译器能对其“持续存在”的部分做一次性分配,然后对其他一些临时张量做复用,就可以显著降低峰值显存。对于在手机端部署大模型这种极端场景,这种编译级内存规划尤为重要。
本节我们通过一个“简化的 LLM 推理”案例,示例哪些算子需要动态形状,以及 Relax 如何自动完成优化。案例中并不会展示完整的大语言模型(那包含非常多层),而是强调动态部分的核心思路。
past_len
会从 0 开始,一步步增长,每次生成一个新 token 就将 KVCache 扩充到 (past_len+1, hidden_dim)
。在 Relax 中,形式可能是这样:
def forward_block(x: Tensor((batch, seq_len), "int32"),
past_kv: Tensor((batch, heads, past_len, dim_per_head), "f16")):
# 1. Embedding => (batch, seq_len, hidden_dim)
emb_out = call_dps_library("my_lookup_table", [x], shape=(batch, seq_len, hidden_dim))
# 2. Self-Attention => 需要 (batch, seq_len, hidden_dim) + past_kv
# 输出新 token 的隐状态,以及更新后的 kv
attn_out, new_kv = call_dps_library("cutlass_attention",
[emb_out, past_kv],
# 形状注释仅示例
out_shape_attn=(batch, seq_len, hidden_dim),
out_shape_kv=(batch, heads, past_len+seq_len, dim_per_head))
# 3. 对 attn_out 做一些 element-wise 操作,如 ReLU
relu_out = call_tir(my_relu, [attn_out], Tensor((batch, seq_len, hidden_dim),"f16"))
# 4. ... 省略更多算子
return relu_out, new_kv
核心看点在于:
(batch, seq_len, hidden_dim)
这类形状为一组动态符号 (b, s, d)
,并与 KVCache 维度 (b, h, p, d')
建立关联。(b, s, d)
与其他同形张量复用(不与 KVCache 复用,因为 KVCache 需要长期保留)。seq_len=1
,past_len
会随轮数不断增加,但 kernel 并不退化为最通用模式,因为 “d
, b
等大部分维度是已知或可做静态特化”,只在 p
维度上做循环。据Relax 论文和社区提供的结果:
除服务器 GPU 场景外,Relax 也非常关注在手机、平板、IoT 设备等硬件上的部署。因为这些设备内存更有限,且可能需要动态形状(如可变输入分辨率、可变语音片段长度、多轮对话上下文等):
这样一来,端侧大模型(如在 iOS、Android 上跑 Llama2 7B)的解决方案更加可行:
编译器 | 动态形状支持 | 跨层抽象 | 典型特点 |
---|---|---|---|
XLA | 部分支持,但限制较多 | 以静态图为主 | 专注于 TPU 上大规模训练/推理;动态形状处理不够灵活 |
Relay(TVM) | 有 any 或 -1 标记 |
单向降级 | 传统上更偏静态形状,对动态形状支持有限 |
MLIR | 取决于具体 Dialect | 无统一跨层 IR | 通用编译基础设施,需自行扩展和封装 |
PyTorch compile | Symbolic shape + Inductor | 没有跨层共享 IR | 主要在 Python 端合并算子,算子级仍依赖库或生成 IR |
Relax | 完整“符号形状”概念 | 图级 + 算子级同层 | 动态形状一等公民,支持可组合的降级与优化 |
从上表可见,Relax 的突出特征在于把“图级 IR”和“算子级 IR”统一在一个 IR 系统中,并且对动态形状的表达力很强,这使得它在应对 LLM 推理、多变输入形状、端上部署等问题时更加从容。
下面给出一个更细的示例,涵盖定义 TensorIR 函数、图级调用、以及如何融合的大致流程,帮助大家直观理解。
@tensorir_function
def my_relu(
X: Buffer((n, m), "f32"),
Y: Buffer((n, m), "f32")
):
# 这里 (n, m) 是符号形状
for i, j in grid(n, m):
with block():
# 计算块
Y[i, j] = max(X[i, j], 0.0)
此处 my_relu
函数描述了如何在低层循环级别执行 ReLU,用 for 循环遍历 (n, m),对每个元素做 max(0, x[i,j])
。在实际生成 GPU 代码时,编译器会根据调度将这个 for-loop 分配到 GPU block/thread。
def main_fn(x: Tensor((n, 4), "f32"),
w: Tensor((4, 8), "f32")) -> Tensor((n, 8), "f32"):
# x shape = (n,4), w shape = (4,8)
with dataflow():
# 假设有一个 MatMul 的 TensorIR 函数 matmul_fn
lv0: Tensor((n, 8), "f32") = call_tir(
matmul_fn,
[x, w],
Tensor((n,8), "f32") # 输出形状注释
)
# 再调用自定义 ReLU
lv1: Tensor((n, 8), "f32") = call_tir(
my_relu,
[lv0],
Tensor((n, 8), "f32")
)
return lv1
call_tir(matmul_fn, [x, w], ...)
表示调用我们编写/调度好的矩阵乘法内核;(n,8)
张量后,再通过 call_tir(my_relu, [lv0], ...)
调用上面定义的 ReLU 内核;(n,4)
, (4,8)
, (n,8)
这几个形状的符号注释,可用于后续融合或内存复用分析。如果编译器发现 matmul_fn
与 my_relu
都是 element-wise 在输出张量 (n,8)
上执行,那么就可决定合并二者:
fused_matmul_relu
的 TensorIR 函数,把 matmul 的结果值在同一个 kernel 中顺便做 ReLU;call_tir
则更新为 call_tir(fused_matmul_relu, ...)
,减少一次内存读写。在实际大模型里,会出现类似“MatMul + LayerNorm + Dropout + Add”之类序列,只要编译器判定形状与数据流吻合,就能通过动态形状符号把它们合并到一个 kernel。
Relax 还在不断演进:
总之,以符号形状为核心、跨层级抽象为骨架的编译体系,将会在未来支撑更灵活、更大规模的深度学习场景。
通过本教程,相信读者可以初步理解:深度学习编译器为何需要在 IR 中“显式地表达”动态形状;又为何“跨层次”统一抽象如此关键。Relax 在这两点上都做了深入设计,并结合了可组合优化(算子融合、内存规划、自动调度、外部库调用)的理念,形成了面向大规模、动态模型部署的强力编译体系。
在实际使用中,TVM 社区也为 Relax 提供了大量实例与工具,包括:
读者如果对这条技术路线感兴趣,可前往 TVM 官方文档 或 Relax 社区 了解最新进展。在大模型浪潮下,让我们期待 Relax 及相关编译技术,能为更多开发者带来动态、高效、可移植的深度学习部署体验。