https://arxiv.org/abs/2105.12485
Comments: Accepted by UAI2021
Subjects: Machine Learning (cs.LG); Programming Languages (cs.PL)
Cite as: arXiv:2105.12485 [cs.LG]
现有挑战:
设计适当的机制来学习程序的语法结构
代码是强结构化的,代码的语义依赖于要表示的具有不同语法结构的程序语句和表达式的组合,不能仅仅采用类似于自然语言的处理方法(简单的将代码建模为单词序列)。
如何使用AST作为预训练模型的输入?
树形结构的预训练任务探索
面向序列的任务直接应用于非顺序结构化AST中存在一些不恰当的问题,因此,需要为树设计新的预训练任务,使预训练模型能够同时从AST中提取语法和语义信息。
主要贡献:
采用transformer的编码器-解码器结构。修改了Transformer的编码器端,只添加了一个完全连接的层来调整输入的维度。
AST以树的形式展示了程序的语法结构。树中的每个节点表示代码中的一个结构。
AST节点分为两类:
用从根节点到叶节点的路径集合表示AST,A = { p1,p2,…,pN },N表示AST中路径的个数。
与AST相应的代码片段被分割成一系列 tokens,[LT] 和 [CLS] 分别被添加在序列的开始和结尾处。
C = [LT],x1,x2,…,x3,[CLS] ,其中 [LT] 是 [LT] 的向量表示, [CLS] 是 [CLS] 的向量表示,M是代码片段的长度。
C 被使用在解码器的输入。[EOS] 是解码器端的句尾标识符。
[LT] 不仅作为解码器端的句子开头标识符,它的值还表示目标编程语言的类型。例如, [LT] = [PLT] 表示语言类型是 Python,[LT] = [JLT] 表示语言为Java,当 [LT] = [UNK] 时表示编码器生成的语言是在预训练阶段未见过的语言。这样定义是因为在将代码片段转换为AST时,隐藏了不同类型语言的实现细节,我们需要提示语言类型,以便模型了解不同编程语言之间的差异。
使用 [CLS] 作为NOP的聚合表示。
每一个 path 是一个 nodes 序列, p i = v 1 i v 2 i . . . v L − 1 i x t i p_i=v^i_1v^i_2...v^i_{L-1}x^i_t pi=v1iv2i...vL−1ixti,path 上的叶子节点 x t i x^i_t xti 是对应的代码片段的一个 token,L是 path 的长度。
我们将路径上的节点向量连接起来以表示路径:
p i = C o n c a t [ v 1 i ; v 2 i ; . . . ; v L − 1 i ; x t i ] ; p_i=Concat[v^i_1;v^i_2;...;v^i_{L-1};x^i_t]; pi=Concat[v1i;v2i;...;vL−1i;xti];
路径集合中的路径表示向量之间没有排序关系。因此,与标准Transformer不同,我们的模型编码器端不添加位置编码来为路径向量分配位置信息,而是在形成节点表示时使用节点位置嵌入来添加树中节点的位置信息。
使用字节对编码(BPE),从AST的值节点和代码片段中学习最常见的 subtoken,并对其进行切片,例如 “third_party” 可能被切片成 “third” ,“-” 和 “party”,使用过程每个token 的所有 subtoken 的向量和来表示完整的 token。
AST中的类型节点数量固定且较少,直接通过embedding将其表示为 实值向量。
一个节点的 position embedding 是其父节点的 position embedding 与它相应的 level embedding 的线性组合。
由 H+1 个 level embedding 作为参数,即 W l e v e l W^{level} Wlevel,其中H为树的高度。我们使用 W 0 l e v e l W^{level}_0 W0level作为根节点的 parent position embedding。如果在 第 j 层有一个节点,它的 position embedding 是 W p a r e n t W^{parent} Wparent 并且它由 c 个子节点,那么它的第 i 个子节点的 position embedding 表示为:
其中 W p a r e n t W^{parent} Wparent , W l e v e l W^{level} Wlevel 是可学习的线性矩阵。Node Posiotion Embedding 可以获得层次信息和节点的父节点和兄弟节点的相对位置信息。
给定 AST-code 片段对,提出一种屏蔽AST中的节点和代码片段的tokens的策略。
在编码器端,首先根据概率 { q n i } n = 1... L \{q^i_n\}_{n=1...L} {qni}n=1...L的分布对路径 p i p_i pi 上的节点进行采样,并使用 TOPK() 操作去选择概率最大的k个节点 m i A m^A_i miA,然后用一个特殊 token [mask] 替换路径 p i p_i pi 中的这些节点,得到 p i m a s k e d p^{masked}_i pimasked
其中 A m a s k e d A^{masked} Amasked 代表 masked 路径的结合, l l l 是当前的节点层次, L L L 是路径中最大的节点层数, N N N 为AST中包含的路径个数, i = 1... N i=1...N i=1...N。注意, L L L 被减去是为了防止数值溢出,这确保了路径中在较大层次的节点被屏蔽的概率更高。
TMLM以更高的概率屏蔽路径中靠近终端的节点。主要原因是:
在解码器端,解码器的输入 C m a s k e d C^{masked} Cmasked是通过屏蔽代码片段中的 tokens 获得的,屏蔽公式如下:
其中 m A = m 1 A ∪ m 2 A ∪ . . . ∪ m N A m^A=m^A_1∪m^A_2∪...∪ m^A_N mA=m1A∪m2A∪...∪mNA, x x x 是集合 m C m^C mC 中需要被 mask 的元素。
我们保留与 m A m^A mA 中的值节点对应的 tokens,屏蔽代码片段 C 中的其他节点。这样,通过下一个 token 的预测,TMLM可以强制解码器依赖于 AST 的特征表示,而不是代码片段中的 previous token。
下图显示了一个示例,灰色节点意味着节点被 masked。根据前面的策略,在AST 中,四条路经被 masked 的节点集合为: m 1 A = { v 4 , x 1 } , m 2 A = { x 2 } , m 3 A = { } , m 4 A = { x 4 } m^A_1=\{v_4,x_1\},m^A_2=\{x_2\},m^A_3=\{\},m^A_4=\{x_4\} m1A={v4,x1},m2A={x2},m3A={},m4A={x4},在代码片段中,值节点 x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3 是被给出的,其它节点是被 masked 的,即 m C = { x 3 , x 5 } m_C=\{x_3,x_5\} mC={x3,x5}。解码器需要做的是预测完整的代码片段 x 1 , x 2 , x 3 , x 4 , x 5 , x 6 x_1,x_2,x_3,x_4,x_5,x_6 x1,x2,x3,x4,x5,x6。
在TMLM中,**编码器读取被屏蔽的AST路径集合,然后解码器推断出与AST对应的代码片段。**当代码转换为AST时,隐藏了一些语义信息,如“+”,“>”,“<=”等二进制操作符在AST中使用“BinOpSub”节点表示。在这种情况下,如果解码器被设计为预测AST,则上述语义信息将被忽略。因此,我们设计解码器来预测代码片段,以鼓励模型推断这些语义信息,从而增强其在下游任务中的泛化能力。
总之,TMLM可以强制编码器理解AST并推断隐藏在AST中的语义信息。
为了进一步提高从程序中提取语法结构信息的能力,我们设计了二值化预训练任务NOP。
AST中节点的顺序有一些隐式约束。以上图中的AST结构为例,“if”节点下必须有一个“body”节点,“body”节点下必须有一个“Expr”节点。为了获取这种语法结构信息,我们以一定的概率决定是否随机交换路径中某些节点的位置,然后训练模型来区分AST中节点的顺序是否正确。如图2所示,我们交换节点v3和v5的位置(图中绿色的节点代表交换位置)。
[CLS] 的隐向量通过一个全连接层压缩到一维,然后通过sigmoid函数得到AST路径中无需节点存在的概率 y’
为了验证TreeBERT的有效性,TreeBERT对两个生成任务进行了微调,并与基线进行了比较。生成任务是代码总结和代码文档。我们还评估了TreeBERT在c#数据集上的性能,并通过实验证明TreeBERT可以很好地推广到预训练阶段未见的编程语言。
改进策略: