在深度学习框架中,Expression Layer(表达式层)是指一个通用的算子,其允许深度学习网络的不同层之间结合和嵌套,从而支持一些更复杂的操作,如分支之间的加减乘除(elementAdd
)等。
在PNNX
的Expession Layer
中给出的是一种抽象表达式,会对计算过程进行折叠,消除中间变量,并且将具体的输入张量替换为抽象输入@0, @1等。PNNX生成的抽象表达式是这样的:
add(@0,mul(@1,@2))
add(add(mul(@0,@1),mul(@2,add(add(add(@0,@2),@3),@4))),@5)
这就要求我们需要一个表达式解析
和语法树构建
的功能。
表达式解析:将表达式拆分成多个token。
语法树构建:这是一个二叉树,以操作符作为根节点,可根据后序遍历生成逆波兰表达式。
因此,我们定义了如下类型:
TokenType 是一个枚举类型,用于表示不同的符号类型
。该枚举包含了 TokenUnknown、TokenInputNumber、TokenComma、TokenAdd、TokenMul、TokenLeftBracket、TokenRightBracket 七种类型。
Token 是一个结构体,表示每个具体的 Token
的信息,包括 Token 类型、Token 开始和结束的位置 3 个属性。
TokenNode 是一个结构体,表示语法树的节点
,包括节点存储的数值、左子节点和右子节点 3 个属性。
ExpressionParser 是一个生成符号表达式语法树
的类,包括 Tokenizer()、Generate()、tokens()、token_strs() 等方法,用于对输入的字符串进行分析和处理,并生成符号表达式语法树。
表达式解析器 ExpressionParser 的 Tokenizer() 用于对输入的字符串进行词法分析。
即,词法分析将输入字符串切分为多个 token,使得后续的语法分析更具有可操作性。
具体如下:
Tokenizer() 方法的参数 re_tokenize 用于控制是否需要对已经存在的 tokens 进行重新处理,如果不需要重新处理且 tokens 已经存在,则直接返回。这样可以避免在不需要处理时多次调用词法分析函数。
通过 std::isspace© 对每个字符进行检查,去除所有空格和制表符。
对于不同的字符,Tokenizer() 方法使用 if/else 分支结构判断,生成不同类型的 token,包括 TokenAdd、TokenMul 和 TokenInputNumber。
- 当字符为a的时候,我们的词法规定在token中以a作为开始的情况只有add,所以我们判断接下来的两个字符必须是d和d,如果不是的话就报错,如果是的话就初始化一个新的token并进行保存。
- 当字符为m的时候,我们的语法规定token中以m作为开始的情况只有mul,所以我们判断接下来的两个字必须是u和l,如果不是的话就报错,是的话就初始化一个token进行保存mul。
每当分析出一个 token 后,就将其记录到 tokens_ 和 token_strs_ 中,以备后续语法分析过程的使用。tokens_ 和 token_strs_ 分别存储 token 和对应的文本形式字符串。
Tokenizer() 方法的主要作用是对输入的字符串进行切割,将其划分为多个基础的符号,方便后续的语法分析进一步处理。
通过这种方式,表达式解析器可以更好地理解输入的表达式,从而为相应的计算程序提供原始数据和计算逻辑。
ExpressionParser 的 Generate_() 方法用于生成符号表达式语法树。
Generate_() 方法是一个递归函数,用于从输入的 Token 序列中按照先序遍历
顺序生成符号表达式语法树的节点,并返回根节点。
其中:
Generate_() 方法接受一个整型引用参数 index,用于指示当前正在处理的 token 的下标,在递归结构中传递以保证后续的深度遍历。
Generate_() 方法首先对 index 所指向的 token 进行检查,如果这个 token 不是数字、加号或乘号,则抛出异常。
如果当前 token 是数字
,则从 statement_ 字符串中提取数字并创建一个没有子节点的叶子节点
,并返回该叶子节点的指针。
如果当前 token 是加号或乘号
,则需要生成一个新节点并将其 left 和 right 子节点连接起来。然后递归调用本身
以生成其 left 和 right 子节点,并将指向子节点的指针赋值给新节点的相应成员变量。
在处理加号或乘号情况时,需要保证当前 token 后面的第二个 token 是左括号,如果不是则抛出异常;而且还需要处理其中的 left 和 right 子节点,左节点与右节点之间用逗号隔开。
最后,Generate_() 返回该树的根节点。
简单来说:
遍历tokens_容器中的每一个token
- 如果当前Token类型是数字,那么需要创建并返回TokenNode
- 如果当前Token类型是mul或者add,那么需要向下递归构建对应的左子树和右子树
- 1,判断是不是“(”
- 2,获取@num
- 3,递归调用自身,创建左子树
- 4,判断是不是“,”
- 5,获取@num
- 6,递归调用自身,创建右子树
- 7,判断是不是“)”
Generate_() 方法实现了从输入 Token 序列构建符号表达式语法树的逻辑,具有很好的可靠性和健壮性。
ReversePolish() 函数接受一个指向符号表达式语法树根节点的指针和一个 vector,用于存储生成的后缀表达式。
其中:
ReversePolish() 函数采用后序遍历算法,先处理左子节点,然后是右子节点,最后才处理根节点。
当 root_node 不为空时,递归调用 ReversePolish() 函数以处理其 left 和 right 子节点,使得 vector 中保存的各个节点的次序即根据后缀表达式的次序排列。
对于根节点,将其推入 vector 的末尾,即实现了后缀表达式的构建。
该函数的主要作用是将一个给定符号表达式语法树转换成后缀表达式,并保存在 vector 中,方便后续基于后缀表达式的计算。
通常在表达式计算场景中,后缀表达式具有更好的可处理性和操作性,因此其转换成后缀表达式后可更快地求值。
我们定义了一个名为 ExpressionLayer 的类,继承自 Layer 类,用于将输入的表达式字符串解析为后缀表达式,并对给定的张量进行基于后缀表达式的计算。
其中:
ExpressionLayer 的构造函数接受一个字符串表达式 statement,创建一个 ExpressionParser 对象。
ExpressionLayer 的 Forward() 函数是重写 Layer 接口的纯虚函数,输入参数为一个浮点数类型的张量列表 inputs,输出参数为一个浮点数类型的张量列表 outputs。Forward()函数通过 ExpressionParser 对象解析出的后缀表达式计算结果,并更新到 outputs 参数。
GetInstance() 是 ExpressionLayer 的一个静态成员函数,用于创建 ExpressionLayer 对象并返回一个指向该对象的智能指针。该函数接受一个 RuntimeOperator 指针对象作为参数,并利用 op->param_attr 中包含的表达式串创建ExpressionLayer 对象,然后返回该对象的智能指针。
ExpressionLayer 类主要作用是将用户输入的表达式字符串解析为后缀表达式,并提供基于后缀表达式的计算功能。
ExpressionLayer 类中的静态成员函数 GetInstance() 用于从 op 中解析出表达式字符串,创建 ExpressionLayer 对象,并返回 ParseParameterAttrStatus 枚举类型。
首先,检查 RuntimeOperator 对象 op 是否为空,如果为空则抛出异常。
然后,检查 params map 中是否包含 “expr” 字符串作为 key 。如果不包含,则返回 kParameterMissingExpr 枚举。
接下来,检查 params[“expr”] 对象是否为语句字符串类型的运行参数。如果不是,则返回 kParameterMissingExpr 枚举。
最后,通过 std::make_shared 创建 ExpressionLayer 指针对象,传入 statement_param->value 来初始化 ExpressionLayer 类的一个新实例;同时返回 kParameterAttrParseSuccess 枚举,说明解析成功。
此外, LayerRegistererWrapper kExpressionGetInstance("pnnx.Expression", ExpressionLayer::GetInstance)
使得 ExpressionLayer::GetInstance() 函数在工厂类中注册绑定到名称为 “pnnx.Expression”,以便根据之前已经被注册的名称获取对应的层实例。
ExpressionLayer 类的成员函数 Forward() 用于根据输入张量列表计算输出张量,并返回 InferStatus 枚举类型。
首先,对输入和输出张量进行检查。如果输入张量为空或者输出张量为空,则分别返回 kInferFailedInputEmpty 和 kInferFailedInputOutSizeAdaptingError 两种错误码。
随后,检查解析器对象指针是否为空,表达式解析
并获取 tokens,如果解析失败,输出错误。
接下来,检查所有的输入张量是否为非空且一个维度中至少有一个元素,如果条件不满足,则返回 kInferFailedInputEmpty 错误码。
使用 batch_size 获取输出张量数量,并在遍历输出张量时对每个张量进行初始化(赋值为0)。
然后创建一个栈用于保存从张量张量取出来的操作数。接着,构建语法树
,之后进入循环,依次遍历 TokenNode 树中的每个节点。
最后,当遍历结束时,此时栈中仅剩下一个元素,即为计算结果。将其出栈,并将输出张量列表中的操作数更新为该元素即可。
ExpressionLayer实现了逆波兰表达式的计算过程,支持二元运算符加、乘,并能够处理 batch 式数据。