TensorFlow图优化(一)-CSE(公共子表达式消除)

    TensorFlow中有很多图优化,包括公共的优化策略和针对设备的特殊优化,有兴趣可以针对自己的网络设计专门的优化。对于大型公司来说,为了提高计算效率进行专门优化很有必要,对于个人学者来说,不如买个更牛逼的卡。

一、图优化的执行时序

    图优化的入口是第一次创建Session后,执行Sess Run时会创建executor,即调用CreateExecutors()。这个接口中有两个分支进行图优化,其中一个分支实例化GraphOptimizer父类,父类中完成了CSE和ConstantFold等图优化;另一个分支进行了更多的优化,包括针对CPU、GPU等设备的专用优化,如内存优化、GPU转化NHWC为NCHW等。这篇介绍CSE的执行时序和优化原理。

    懒得画图,CSE优化的执行时序如下:

CreateExecutors实例化GraphOptimizer-->调用optimizer.Optimize-->判断option中是否开启CSE,开启则调用OptimizeCSE-->实例化OptimizerCSE,调用其Optimize方法开始优化。

    其中consider_fn是判断节点是否可以合并的函数指针,默认为空,实际也没传值,因此这里不需要考虑。

二、优化原理

    看到源码前,没有想到实现如此简洁。CSE的目的是将相同输入的表达式进行消除,由一个节点来代替,复用计算结果。源码如下。

TensorFlow图优化(一)-CSE(公共子表达式消除)_第1张图片

1、得到图的逆后续节点集,代码实现中,首先得到图的深度优先遍历,然后将结果逆转,就得到了逆后续节点集,开发人员没有给出使用逆后续的原因,个人考虑,逆后续得到的结果就是拓扑排序,即访问到某一个节点时,该节点的依赖节点都已经被访问。

2、创建一个map,名为available,作用是存储候选集,在后续遍历图中节点时,可以从候选集中查找是否有可以使用的表达式。

3、遍历图所有节点。

4、判断表达式相同的关键思想。获取节点的hash值,hash的key由输出个数+输出类型s+(输入节点id+端口)s,这样设计key可以保证输入相同及输出个数和类型相同时,得到的hash值相同,达到检索公共表达式的目的,而hash的特点就是可以在O(1)的时间复杂度根据key得到value。

5、根据hash值h从map中得到candidate,当candidate是空时,表示第一次遇到这样的表达式,将节点存入map中。

6、做candidate非空,则说明之前遍历到了相似的表达式,接下来需要进一步判断是否可复用已保存节点表达式的结果。Equivalent函数完成了这样的判断,大致是节点类型不同不可,输入非const不可,输入类型中有ref类型不可(ref类型不使用,enum大于100),节点attrs不同不可,输入个数不同不可,输入节点id和端口不同不可。即节点的输入都是来自相同的const节点,可以保证输入数据完全相同,而输出个数和类型相同可以保证输出的结果相同,因此可以复用之前保存节点的结果,不用再次进行计算。

7、判断表达式可以复用后,最后一步就是删除重复的节点,方法很简单,直接将candidate的输出连接到当前节点的输出节点对应的输入端口,最后删除当前节点即可。

 

    之前看的一点内容,粗略记录一下,不记录哪天就忘记了。有偏颇的地方难以避免。以后有需要再完善。

你可能感兴趣的:(TensorFlow源码解析,TensorFlow,图优化,CSE)