4.自动求导框架架构

本节是自动求导框架技术的第四节,本系列其余文章包括


自动求导框架综述

1. 矩阵求导

2. 链式法则与计算图

3. 控制流与其实现思路

5. 使用自动求导框架实现RNN


自动求导框架主要包括以下几个类:

Tensor:封装了张量以及张量的基本运算,也就是一个任意维度的数组。

Graph:图类,是虚拟图和计算图的基类,包括了图的基本算法:构件图的邻接表,图的转置,根据给定的节点对图进行剪枝,图的拓扑排序等。

Node:节点类,是虚拟节点,计算节点等节点的基类,是Graph的组成部分。

VirtualGraph:虚拟图类,主要包含一个前向传播构建计算图的方法 build_compute_graph。

VirtualNode:虚拟节点类,组成虚拟图,内部包含了一个计算节点工厂,可以根据虚拟节点的类型生成计算节点。每个虚拟节点生成的计算节点引用都会被缓存在虚拟节点中,使得用户可以通过虚拟节点访问到生成的计算节点;虚拟节点还可以控制生成的计算节点是否共享参数,默认是不共享。共享参数这一功能对于rnn的构造是很有必要的。

LoopNode 和 BranchNode,虚拟图中的控制节点,用于实现控制流。其中 LoopNode 内包含了一个子虚拟图。

ComputeGraph:计算图,由虚拟图构造出来,包括了前向传播和反向求导等功能。之所以和虚拟图VirtualGraph 一样包含了前向传播的功能,是因为大部分网络对于不同的数据实际上不会拥有动态的网络结构,计算图由虚拟图构造出来之后结构就不会变化了。所以计算图提供前向传播的功能是为了优化的考虑。

OperatorNode:计算节点的基类,主要包括 op () 前向计算函数和 grad_op () 求解梯度的函数,其子类包括参数节点 Parameter 类,数据输入节点 Input 类,各种运算比如基于张量的加法,减法,乘法,偏置,sigmoid 函数,二阶范数,一阶范数。


4.自动求导框架架构_第1张图片
自动求导框架类图

你可能感兴趣的:(4.自动求导框架架构)