mxnet代码剖析之--Symbol篇

Ver2.0:

struct Node:
std::unique_ptr op; /// 节点操作函数类,变量节点指针为空
std::string name; /// 节点名称
std::vector inputs; /// 节点的输入口,包括上层节点的输出,变量节点等
std::shared_ptr backward_source_node; /// 如果支持backward,表示反向计算时的下一级节点
std::unique_ptr > attr;

类说明:
1 每个Node对应python中的一个symbol,变量符号/操作符号
2 操作符号应该包含操作函数,输入符号,输出符号,如果允许backward计算,包括反向source符号
3 变量符号操作函数句柄为空,并且不支持backward计算

struct DataEntry
std::shared_ptr source; /// 节点指针
uint32_t index; /// 节点索引号,表示DFS遍历生成的序号

Class Symbol:
std::vector heads_;

类说明:
1 提供与python语言symbol之间的接口,表达符号计算图(网络)
2 heads中每个元素表示当前符号的输出节点,操作子一般输出节点个数为1,除以下节点:
batch_norm: output, mean, var
dropout: output, mask
leaky_relu: output, mask
lrn: output, tmp_norm
3 符号的变量一般包括源符号的输出output,以及本节点的参数(weight, bias, mask 等)
4 符号输出ndarray的shape/type一般由输入ndarray以及操作类型共同决定

部分函数接口说明:
void Symbol::DFSVisit(FVisit fvisit):实现函数fvisit对所有节点的深度优先遍历访问
void KeywordArgumentMismatch():匹配用户输入变量与内部变量,实现变量key检查
int FindDuplicateArgs(): 遍历网络,找到重复出现的变量,返回最大重复值
std::vector ListArguments():遍历网络,返回所有的变量节点
std::vector ListOutputs():遍历网络,返回所有的输出节点
Symbol operator[](size_t index): 返回第index个output节点
void Compose(args, name):
1 变量节点不支持组合操作
2 所有的args节点有且仅有一个输出节点
3 如果父节点是原子符号:将所有的args节点赋值到heads_[0].source.inputs,其余部分使用默认变量补齐
4 否则对比已有输入与args,用args替换已有输入节点
void ToStaticGraph(StaticGraph *out_graph) :转换为StaticGraph结构

/// ===================================================================================
1 关于Node
1.1 网络结构中的基本单元
1.2 包括操作入口,唯一标识符,输入列表,如果是反向传递节点,backward_source_node指向相应的正向节点
1.3 变量节点操作句柄为空,正向传递节点backward_source_node为空
1.4 添加atomic属性,即所有未添加输入参数的非变量节点。为了实际应用中的方便,无实际意义

2 关于Symbol
2.1 每个符号维护一个输出队列容器,容器大小由NumOutputs()决定,其中第一个输出变量为网络中传递变量(下一节点输出变量),其它变量做为辅助变量,用于反向梯度计算
2.2 对于所有的包含操作的符号,NumVisibleOutputs = 1,即仅支持单输出链接(其它输出仅做为反向传递时的辅助变量)
2.3 符号compose,组合输入输出到操作(不支持变量符号)


==================================================================================================================================

Ver1.0


struct Node:

std::unique_ptr op; /// 节点操作函数类,变量节点指针为空
std::string name; /// 节点名称
std::vector inputs; /// 节点的输入口,包括上层节点的输出,变量节点等
std::shared_ptr backward_source_node; /// 如果支持backward,表示反向计算时的下一级节点
std::unique_ptr > attr;

类说明:
1 每个Node对应python中的一个symbol,变量符号/操作符号
2 操作符号应该包含操作函数,输入符号,输出符号,如果允许backward计算,包括反向source符号
3 变量符号操作函数句柄为空,并且不支持backward计算
/// --------------------------------------------------------------------------------------------------------------------------
struct DataEntry
std::shared_ptr source; /// 节点指针
uint32_t index; /// 节点索引号,表示DFS遍历生成的序号


/// --------------------------------------------------------------------------------------------------------------------------

Class Symbol:
std::vector heads_;

类说明:
1 提供与python语言symbol之间的接口,表达符号计算图(网络)
2 heads中每个元素表示当前符号的输出节点,操作子一般输出节点个数为1,除以下节点:
batch_norm: output, mean, var
dropout: output, mask
leaky_relu: output, mask
lrn: output, tmp_norm
3 符号的变量一般包括源符号的输出output,以及本节点的参数(weight, bias, mask 等)
4 符号输出ndarray的shape/type一般由输入ndarray以及操作类型共同决定

部分函数接口说明:
void Symbol::DFSVisit(FVisit fvisit):实现函数fvisit对所有节点的深度优先遍历访问
void KeywordArgumentMismatch():匹配用户输入变量与内部变量,实现变量key检查
int FindDuplicateArgs(): 遍历网络,找到重复出现的变量,返回最大重复值
std::vector ListArguments():遍历网络,返回所有的变量节点
std::vector ListOutputs():遍历网络,返回所有的输出节点
Symbol operator[](size_t index): 返回第index个output节点
void Compose(args, name):
1 变量节点不支持组合操作
2 所有的args节点有且仅有一个输出节点
3 如果父节点是原子符号:将所有的args节点赋值到heads_[0].source.inputs,其余部分使用默认变量补齐
4 否则对比已有输入与args,用args替换已有输入节点
void ToStaticGraph(StaticGraph *out_graph) :转换为StaticGraph结构

你可能感兴趣的:(深度学习,mxnet)