mxnet代码解析之nnvm

概述

nnvm启发于LLVM,它利用operator的高层信息去优化计算图;nnvm是从mxnet的实现中剥离出来一个模块,该模块完成了从symbol描述的网络到graph描述的符号计算图的生成和优化工作,而这样的模块化剥离仿效了unix的哲学,使得mxnet能够在不同的设备应用和场景中自主裁剪各功能模块。

nnvm中的graph包含了计算图的结构,并且包含了一个从字符串到任意类型的属性映射map< string, shared_ptr< any > >,这个属性映射包含了每一个tensor的shape、type以及内存分配计划。

nnvm中的pass就是对包含了各种属性映射信息的计算图执行转换,转换使得该计算图拥有更多地属性或者变为另一个计算图。nnvm中实现的pass包括自动差分计算、形状和类型推断、内存计划等。

nnvm采用NNVM_REGISTER_OP去注册一个operator,并可以对不同的op用set_attr注册不同的属性,使得不同operator的实现不必采用同一个operator接口,完成了去中心化的设计目标,使得不同框架下的operator实现都可以采用nnvm做计算图优化。

mxnet v0.7中graph的解析参照http://blog.csdn.net/chaojichaoachao/article/details/52026799。

Op

//一个Op就是一个操作,一个Node对应一个Op。
class Op {
 public:
  std::string name;
  //operator的描述,可用于自动生成docstring
  std::string description;
  //输入和关键参数的描述
  std::vector arguments;
  uint32_t num_inputs = 1;
  uint32_t num_outputs = 1;
  std::functionconst NodeAttrs& attrs)> get_num_outputs = nullptr;
  std::functionconst NodeAttrs& attrs)> get_num_inputs = nullptr;
   //解析属性的函数指针,该函数将解析结果放到attrs->parsed中去,可以调用
   //nnvm::get<参数类型>获取属性的该参数。
  std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
  inline Op& describe(const std::string& descr);  // NOLINT(*)
  inline Op& add_argument(const std::string &name,
                          const std::string &type,
                          const std::string &description);
  inline Op& add_arguments(const std::vector &args);
  inline Op& set_num_inputs(uint32_t n);  // NOLINT(*)
  inline Op& set_num_inputs(std::functionconst NodeAttrs& attr)> fn);  // NOLINT(*)
  inline Op& set_num_outputs(uint32_t n);  // NOLINT(*)
  inline Op& set_num_outputs(std::functionconst NodeAttrs& attr)> fn);  // NOLINT(*)
  inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn);  // NOLINT(*)
  template<typename ValueType>
  inline Op& set_attr(const std::string& attr_name,  // NOLINT(*)
                      const ValueType& value,
                      int plevel = 10);
  Op& add_alias(const std::string& alias);  // NOLINT(*)
  //从一个注册好的op group中将所有属性包含进来
  Op& include(const std::string& group_name);
  static const Op* Get(const std::string& op_name);
  template<typename ValueType>
  static const OpMap& GetAttr(const std::string& attr_name);

 private:
  template<typename ValueType>
  friend class OpMap;
  friend class OpGroup;
  friend class dmlc::Registry;
  uint32_t index_{0};
  Op();
  static const any* GetAttrMap(const std::string& key);
  static void UpdateAttrMap(const std::string& key,
                            std::function<void(any*)> updater);
  //基于合适的tag属性匹配添加一个触发器,注册时调用了include的op也都会应用这个触发器
  static void AddGroupTrigger(const std::string& group_name,
                              std::function<void(Op*)> trigger);
};

Node

using NodePtr = std::shared_ptr;

//表示一个节点输出数据的项,一个NodeEntry以node的其中一个输出的视角来描述
struct NodeEntry {
  //数据的源节点
  NodePtr node;
  //该输出的索引
  uint32_t index;
   //输入变量的版本
   //node是一个变量节点时,这个值只能是非0,变量每次参与一个修改Op时version都会加1,
   //这个信息在一个修改序列发生时对于决定操作的顺序有帮助
  uint32_t version;
};
struct NodeAttrs {
  const Op *op{nullptr};
  std::string name;
  //位置属性的向量表示
  std::vector<double> scalars;
  //属性的字典表示
  std::unordered_map<std::string, std::string> dict;
   //any是任意类,如果注册了OpProperty.attr_parser就会生成,可快速访问属性
  any parsed;
};
class Node {
 public:
  NodeAttrs attrs;
  std::vector inputs;
   //当前节点操作之前必须完成的操作
  std::vector control_deps;
  ~Node();
  inline const Op* op() const;
  inline bool is_variable() const;
  inline uint32_t num_outputs() const;
  inline uint32_t num_inputs() const;
  static NodePtr Create();
};

Any

前面提到的any可以表示任意类,与c++17的 std::any兼容,它的定义如下:

class any {
 public:
  inline any() = default;
  inline any(any&& other);  // NOLINT(*)
  inline any(const any& other);  // NOLINT(*)
  template<typename T>
  inline any(T&& other);  // NOLINT(*)
  inline ~any();
  inline any& operator=(any&& other);
  inline any& operator=(const any& other);
  template<typename T>
  inline any& operator=(T&& other);
  inline bool empty() const;
  inline void clear();
  inline void swap(any& other); // NOLINT(*)
  inline const std::type_info& type() const;

 private:
  template<typename T>
  class TypeOnHeap;
  template<typename T>
  class TypeOnStack;
  template<typename T>
  class TypeInfo;
  //栈空间大小,一个any类型32比特
  static const size_t kStack = sizeof(void*) * 3;
  static const size_t kAlign = sizeof(void*);
  // container use dynamic storage only when space runs lager
  //当空间变得更大时容器只使用动态内存
  union Data {
    std::aligned_storage::type stack;// 栈空间    
    void* pheap;// 指向堆空间
  };
  struct Type {
    void (*destroy)(Data* data);
    void (*create_from_data)(Data* dst, const Data& src);
    const std::type_info* ptype_info;
  };
  //检查数据是否能存储在堆空间的常数
  template<typename T>
  struct data_on_stack {
    static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack;
  };
  template<typename T>
  friend T& get(any& src);  // NOLINT(*)
  template<typename T>
  friend const T& get(const any& src);
  inline void construct(any&& other);
  inline void construct(const any& other);
  template<typename T>
  inline void check_type() const;
  const Type* type_{nullptr};
  // 核心数据
  Data data_;
};

any类型中包含了右值引用的使用,如下是三个对operator=的重载:

inline void any::swap(any& other) { // NOLINT(*)
  std::swap(type_, other.type_);
  std::swap(data_, other.data_);
  }
  inline any& any::operator=(any&& other) {
  any(std::move(other)).swap(*this);
  return *this;
}

inline any& any::operator=(const any& other) {
  any(other).swap(*this);
  return *this;
}

template<typename T>
inline any& any::operator=(T&& other) {
  any(std::forward(other)).swap(*this);
  return *this;
}
}

any可以安全地表示任意类型,因此NodeAttrs中的parsed可以是任意类型的param,在需要时就可以通过nnvm::get< Type >取出来。

Graph

//symbol是面向前端接收前端定义的网络,然后在后端转化成优化需要的计算图Graph:
class Symbol {
 public:
 ...
 ...

  //输出项,对应于原来的 heads_
  std::vector outputs;
 ...
  void Compose(const array_view<const Symbol*>& args,
               const std::unordered_map<std::string, const Symbol*>& kwargs,
               const std::string& name);

  void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);

  bool GetAttr(const std::string& key, std::string* out) const;
  ...
  //Variable, Functor and Group三种symbol组件的创建
  static Symbol CreateFunctor(const Op* op,
                              std::unordered_map<std::string, std::string> attrs);
  static Symbol CreateVariable(const std::string& name);
  static Symbol CreateGroup(const std::vector& symbols);
};

符号计算图

class Graph {
 public:
  std::vector outputs;
   //高度推荐保持每个属性不可变,这样实现写时拷贝的场景就安全了。shared_ptr.unique为true
   //则拷贝,为真则重复利用原始空间。
  std::unordered_map<std::string, std::shared_ptr > attrs;
  template<typename T>
  inline const T& GetAttr(const std::string& attr_name) const;
   //获取属性的移动拷贝,实现了写时拷贝的场景。引用计数为1时在调用后从attrs中擦除。
  template<typename T>
  inline T MoveCopyAttr(const std::string& attr_name);
  const IndexedGraph& indexed_graph();

 private:
  // 索引图的内部结构
  std::shared_ptr<const IndexedGraph> indexed_graph_;
};

class IndexedGraph {
 public:
  struct NodeEntry {
    uint32_t node_id;//区别于nnvm::NodeEntry,这里只是索引
    uint32_t index;
    uint32_t version;
  };
  struct Node {
    const nnvm::Node* source;
    array_view inputs;
    array_view control_deps;
  };
  。。。
  inline size_t num_node_entries() const {
    return entry_rptr_.back();
  }
  //获取一个0~num_node_entries()的唯一entry id
  inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
    return entry_rptr_[node_id] + index;
  }
  。。。
   //给定node_id获取对应的节点结构
  inline const Node& operator[](uint32_t node_id) const {
    return nodes_[node_id];
  }
   //返回对应的IndexedGraph::Node的常引用
  inline const Node& operator[](const nnvm::Node* node) const {
    return nodes_[node_id(node)];
  }
  。。。
  。。。
  IndexedGraph(const IndexedGraph&) = delete;
 private:
  friend class Graph;
  。。。
  // Node pointers in CSR structure.
  std::vector nodes_;
  std::vector input_nodes_;
  std::unordered_set mutable_input_nodes_;
  std::vector outputs_;
  std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
  // CSR pointer of node entries
  std::vector entry_rptr_;
  std::vector input_entries_;
  std::vector control_deps_;
};

array_view为dmlc中定义的只读数组,用来访问连续内存,它为vector、array、c stype array 提供了统一的视角,这个数据结构不保证它所引用的数组的活动性,因此不要用它在异步函数闭包中记录数据,也不要用它创建临时数据结构的引用。

Pass

pass函数是一个graph上的operator,其定义和应用函数如下:

typedef std::function PassFunction;
Graph ApplyPasses(Graph src,const std::vector<std::string>& passes);
struct PassFunctionReg
    : public dmlc::FunctionRegEntryBase {
  bool change_graph{false};
  std::vector<std::string> op_attr_dependency;
  std::vector<std::string> graph_attr_dependency;
  std::vector<std::string> graph_attr_targets;
  PassFunctionReg& set_change_graph(bool v) {  // NOLINT(*)
    change_graph = v;
    return *this;
  }
   //声明这个pass一旦应用到graph上将会生成给定的graph属性名称
  PassFunctionReg& provide_graph_attr(const std::string& attr_name) {  // NOLINT(*)
    graph_attr_targets.push_back(attr_name);
    return *this;
  }
   //声明这个pass要求给定的operator属性在应用到graph上之前保证可用
  PassFunctionReg& depend_op_attr(const std::string& attr_name) {  // NOLINT(*)
    op_attr_dependency.push_back(attr_name);
    return *this;
  }
   //声明这个pass要求给定的graph属性在应用到graph上之前保证可用
  PassFunctionReg& depend_graph_attr(const std::string& attr_name) {  // NOLINT(*)
    graph_attr_dependency.push_back(attr_name);
    return *this;
  }
};

当前的pass函数包括

inline Graph LoadJSON(const std::string& json_str);
inline std::string SaveJSON(Graph graph);
//强制规范了正确的读和写的顺序,解决读后写以及写后读的问题
inline Graph OrderMutation(Graph src);
inline Graph InferShape(Graph graph,ShapeVector shape_inputs,std::string shape_attr_key = "");
inline Graph InferType(Graph graph,DTypeVector dtype_inputs,std::string dtype_attr_key = "");
inline Graph PlaceDevice(Graph graph,std::string device_group_attr_key,
           DeviceAssignMap device_assign_map,std::string device_copy_op){
  graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key));
  graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map));
  graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op));
  return ApplyPass(std::move(graph), "PlaceDevice");
}
inline Graph Gradient(
    Graph graph,
    std::vector ys,
    std::vector xs,
    std::vector ys_out_grad,
    std::functionstd::vector&& inputs)> aggregate_fun = nullptr,
    std::function<int(const Node& node)> mirror_fun = nullptr,
    std::functionconst NodeEntry& src, const NodeEntry &like)>
    attr_hint_fun = nullptr) {
  graph.attrs["grad_ys"] = std::make_shared(std::move(ys));

  graph.attrs["grad_xs"] = std::make_shared(std::move(xs));
  graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad));
  if (aggregate_fun != nullptr) {
    graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun);
  }

  if (mirror_fun != nullptr) {
    graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun);
  }

  if (attr_hint_fun != nullptr) {
    graph.attrs["attr_hint_fun"] = std::make_shared(attr_hint_fun);
  }

  return ApplyPass(std::move(graph), "Gradient");
}

gradient实现了自动求导,其主要代码如下,其中注释给出了主要功能块的解释:

//这个是将一个节点的多个输出聚合成一个sum_node
NodeEntry DefaultAggregateGradient(std::vector&& v) {
  if (v.size() == 1) {
    return std::move(v[0]);
  } else if (v.size() == 0) {
    NodePtr zero_node = Node::Create();
    zero_node->attrs.op = Op::Get("__zero__");
    return NodeEntry{zero_node, 0, 0};
  } else {
    NodePtr sum_node = Node::Create();
    sum_node->attrs.op = Op::Get("__ewise_sum__");
    sum_node->inputs = std::move(v);
    return NodeEntry{sum_node, 0, 0};
  }
}

//这个类在计算梯度的过程中临时保存一个节点的梯度
struct GradEntry {
#ifdef _MSC_VER
  NodeEntry sum = NodeEntry{nullptr, 0, 0};
#else
  NodeEntry sum{nullptr, 0, 0};
#endif
  std::vector grads;
  bool need_attr_hint{true};
};

Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;
  using AttrHintFun = std::functionconst NodeEntry& src, const NodeEntry &like)>;

  CHECK_NE(src.attrs.count("grad_ys"), 0)
      << "Gradient require grad_ys to be presented.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0)
      << "Gradient require grad_ys_out_grad to be presented.";
  CHECK_NE(src.attrs.count("grad_xs"), 0)
      << "Gradient require grad_xs to be presented.";
  //xs,ys,ys_out_grad分别对应于源graph中每一个需要计算梯度节点的输入节点、输出节点
  //以及有梯度传回该节点的节点
  const std::vector& ys =
      src.GetAttr<std::vector >("grad_ys");
  const std::vector& ys_out_grad =
      src.GetAttr<std::vector >("grad_ys_out_grad");
  const std::vector& xs =
      src.GetAttr<std::vector >("grad_xs");
  using AggFun = std::functionstd::vector&& inputs)>;
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr("grad_mirror_fun");
  }
  AttrHintFun attr_hint_fun = nullptr;
  if (src.attrs.count("attr_hint_fun") != 0) {
    attr_hint_fun = src.GetAttr("attr_hint_fun");
  }

  // topo sort
  std::vector topo_order;
  //这是一个用来保存梯度的临时map,该map的key集合是所有的输出节点,value是key的所有输出节点
  std::unordered_mapstd::vector > output_grads;

  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  CHECK_EQ(ys.size(), ys_out_grad.size());
 //每一个节点产生了几个输出,就会回传几个梯度,output_grads从源graph的ys_out_grad中得,
 //后面再与已计算出回传到该节点的梯度的节点聚合;每一个输出ys[i]只对应一个输入node,
 //但是ys[i]的输入节点可有很多的输出
  for (size_t i = 0; i < ys.size(); ++i) {
    NodeEntry ograd = ys_out_grad[i];
    output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
  }

  // 用于构建镜像函数以节省内存,如果需要的话
  std::unordered_map mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& n : topo_order) {
      if (mirror_fun(*n)) {
        NodePtr new_node = Node::Create();
        *new_node = *n;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[n.get()] = std::move(new_node);
      } else {
        mirror_map[n.get()] = n;
      }
    }
  }

  // 遍历backward
  static auto& grad_fun_map = Op::GetAttr("FGradient");
  static auto& finfer_shape = Op::GetAttr("FInferShape");

  std::vector out_agg_grads;
  //从后往前计算并传递梯度
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;
    out_agg_grads.clear();
    auto& out_grad_vec = output_grads.at(ptr.get());
    //将当前node的所有输出传回的梯度节点相加
    for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
      GradEntry& e = out_grad_vec[i];
      e.sum = agg_fun(std::move(e.grads));
      if (e.need_attr_hint && attr_hint_fun != nullptr) {
        e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
      }
      out_agg_grads.push_back(e.sum);
    }
    if ((*rit)->inputs.size() != 0) {
      NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
      //核心调用,从out_agg_grads得到input_grads ,grad_fun_map中保存了各种Op的backward函数
      std::vector input_grads = grad_fun_map[ptr->op()](
          fwd_node, out_agg_grads);
      CHECK_EQ((*rit)->inputs.size(), input_grads.size())
          << "Gradient function not returning enough gradient";
      auto git = input_grads.begin();
      //将核心调用得到的当前节点需要传给每一个输入节点对应的梯度节点写入output_grads结构中
      for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
        auto& ge = output_grads[it->node.get()][it->index];
        //如果该节点的backward操作能做shape的推断操作,就不需要调用hint函数了
        if (finfer_shape.count(git->node->op())) {
          ge.need_attr_hint = false;
        }
        ge.grads.emplace_back(std::move(*git));
      }
    }
  }
  // 将output_grads导入输出graph的outputs中,每一项与xs一一对应
  Graph ret;
  ret.outputs.reserve(xs.size());
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // aggregate sum if there haven't been
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
      if (entry.need_attr_hint && attr_hint_fun != nullptr) {
        entry.sum = attr_hint_fun(entry.sum, e);
      }
    }
    ret.outputs.emplace_back(std::move(entry.sum));
  }
  return ret;
}

// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");

除了上述pass函数外,还有一个PlanMemory函数用于内存计划,这个函数需要在GraphExecutor的InitGraph中手动调用ApplyPass函数,g = nnvm::ApplyPass(g, “PlanMemory”)。而真正的内存分配在GraphExecutor的InitDataEntryMemory函数中,该函数将按照内存计划执行分配。

连接前后端

mxnet在c_api_xx源码文件中提供了接口供前端调用,mxnet的前端支持python、R、Scala、Go等一众语言,这些接口将采用nnvm::Symbol接收前端给出的中间计算表示形式,以如下函数为例:

int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
                               mx_uint num_param,
                               const char **keys,
                               const char **vals,
                               SymbolHandle *out) {
  nnvm::Symbol *s = new nnvm::Symbol();
  API_BEGIN();
  const nnvm::Op* op = static_cast<const nnvm::Op*>(creator);
  std::unordered_map<std::string, std::string> kwargs;
  for (nn_uint i = 0; i < num_param; ++i) {
    bool flag = false;
    for (const auto &k : kHiddenKeys) {
      std::string tmp(keys[i]);
      size_t pos = tmp.rfind(k);
      if (pos == 0) {
        kwargs.insert({"__" + tmp + "__", std::string(vals[i])});
        flag = true;
        break;
      } else if (pos != std::string::npos && pos == tmp.length() - k.length()) {
        std::ostringstream os;
        os << "setting variable attributes with " << keys[i] << " is deprecated. "
           << "please instead use\nw = Variable(" << k << "=" << vals[i] << ")\n"
           << "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)";
        throw dmlc::Error(os.str());
      }
    }
    if (!flag)
      kwargs.insert({std::string(keys[i]), std::string(vals[i])});
  }
  *s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs));
  *out = s;
  API_END_HANDLE_ERROR(delete s;);
}

该接口函数根据kwargs创建symbol,kwargs中保存的是属性映射信息的string表达,在CreateFunctor中解析出来。
后端包含了各种功能的operator在不同设备下的实现,在定义operator实现时通过NNVM_REGISTER_OP和MXNET_REGISTER_OP_PROPERTY宏注册,NNVM_REGISTER_OP是新的operator注册方法,MXNET_REGISTER_OP_PROPERTY是老版本mxnet的注册方法,convolution、activation、fullyconnected等都是通过MXNET_REGISTER_OP_PROPERTY注册进Registry类中,新版本的mxnet将在前端间接调用RegisterLegacyOpProp函数将Registry中注册的operator转换到nnvm registry中。

本文将持续更新……

你可能感兴趣的:(deeplearning,c++,DL)