MxNet源码解析(1) KVStore,pslite源码解析

1. 前言

从毕业开始工作已经两个多月,这期间相当一部分的时间都用在了对MxNet的学习上,而在MxNet的众多部分中,又是pslite这一部分接触最多。因此,今天将我一直以来的学习过程中的心得和收获总结在这里,也为以后对MxNet的继续学习做一个铺垫

2. MxNet构成

MxNet作为一个深度学习框架,它最大的特点应该是分布式训练的支持了。从初次接触MxNet到现在的两个多月里,我认为MxNet主要有以下几个大的部分:

  • symbol和graph,负责构建计算的图和反向传播的图
  • Engine,负责根据图的节点的依赖关系,并行运算
  • Parameter server和KVStore,负责参数的同步和传递
  • operator,定义图的节点的op
  • NDArray,数据的存储和计算
  • Executor,对图进行处理用于计算

3. Parameter server

参数服务器的概念并不复杂,主要思想就是,将模型的参数保存在server中,另外通过worker来完成具体的计算任务,当每完成一个计算任务时就会得到对应参数的梯度,这时将梯度传送给server,由server来完成参数的更新,worker再从server那里取回更新后的参数。
在MxNet中,当我们需要进行分布式的训练时,就需要使用到它了。在MxNet中,为了完成参数在不同机器前的同步和更新,主要实现了两大部分。一是pslite,另一个是KVStore

3.1 KVStore

为了更方便理解,我从KVStore开始讲起。在MXNet中,可能很多人并不会直接操作KVStore,在官方文档中,甚至提到,不建议直接操作KVStore,但是,每个人在使用MXNet的过程中,都肯定用到了KVStore。其实,在我们建立module.Module的时候,就会调用KVStorepushpull操作。
kvstore主要分为两种,一种是单机下,一种是多机下。单机下又分为将参数存储在GPU显存和CPU内存上两种情况。

3.1.1 comm.h

comm.h文件中定义了Comm类,该类用于设备间的信息传递,也就是communication。从Comm类中派生出了两个子类CommCPU用于CPU内存通信,CommDevice用于GPU通信。

Comm类中定义了几个纯虚函数:

  • Init:根据存储类型和shape初始化
  • Reduce:输入NDArray的一个vector,返回它们的和
  • Broadcast:将一个NDArray复制到vector中的每一个元素
  • BroadcastRowSparse

CommCPU

将数据复制到CPU内存中,在那里做操作。

  • Init:初始化key对应的KVStore,创建key对应的NDArray,保存在merge_buf_[key].merged中。(不分配内存)
  • Reduce:将输入的vector& src的每个元素求和并返回。当src只有一个元素时,若不是sparse的就直接返回src[0],若是则将src[0]拷贝至merged_buf返回。如果src元素多于一个,那么:
if (stype == kDefaultStorage) {
      std::vector const_vars(src.size() - 1); // 定义engine pushasync的输入,用于engine根据该操作的输入来规划操作的执行
      std::vector reduce(src.size());
      CopyFromTo(src[0], &buf_merged, priority);
      reduce[0] = buf_merged;

      if (buf.copy_buf.empty()) { // copy_buf用于GPU数据拷贝至CPU,由于第0个元素存储在buf_merged,这里只需要src.size()-1个
        buf.copy_buf.resize(src.size()-1);
        for (size_t j = 0; j < src.size() - 1; ++j) {
          // allocate copy buffer
          buf.copy_buf[j] = NDArray(
            src[0].shape(), pinned_ctx_, false, src[0].dtype());
        }
      }
      CHECK(stype == buf.copy_buf[0].storage_type())
           << "Storage type mismatch detected. " << stype << "(src) vs. "
           << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
      for (size_t i = 1; i < src.size(); ++i) {
        CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); // 定义拷贝操作
        reduce[i] = buf.copy_buf[i-1];
        const_vars[i-1] = reduce[i].var(); // 定义拷贝操作的输入
      }

      Engine::Get()->PushAsync( // push该操作至engine,engine会根据输入来规划什么时候执行
        [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
          ReduceSumCPU(reduce);
          on_complete();
        }, Context::CPU(), const_vars, {reduce[0].var()},
        FnProperty::kCPUPrioritized, priority, "KVStoreReduce");

    }
  • Broadcast

CommDevice

  • Init:将存储类型和shape信息存储在sorted_key_attrs_中。
  • InitBuffersAndComm:将vector& src的context信息存储在devs中,通过InitMergeBuffersorted_key_attrs_信息,将所有的KVPairs分别存储在GPU上。
  • Reduce:和CommCPU的reduce一样,同样也是为了累积求和。

3.1.2 kvstore.h

kvstore.h中定义了几个纯虚函数

  • Init:根据参数定义的一组KVPairs初始化
  • Push:将一组KVPairs执行Push操作
  • Pull: Pull操作
  • Updater:用于参数的更新

3.1.3 local 和 device

在初始化KVStore我们需要提供KVStore的类型,在MXNet中提供了localdevice两种用于单机训练时的类型。不论哪种,都在文件kvstore_local.h中定义。两者最主要的区别就是对Comm的选择,local会使用CommCPU来进行comm_的初始化,device使用CommDevice来初始化。
KVStoreLocal有以下几个重要的方法:

  • Init:设置key的类型(str或者int),进行初始化。初始化的方法是使用comm_的初始化方法,同时还会在local_保存一个pinned_ctx_类型的拷贝。pinned_ctx_指的是不会被移出cache的内存。
  • Push:根据输入的KVPairs,使用comm_Reduce方法,进行相同keyvalue的求和。并且如果注册了updater_的话,会调用updater_进行更新。在进行更新之前,如果是在GPU端更新,会先将保存在local_的参数拷贝至GPU。
  • Pull:Pull方法主要的工作是将存储在local_的参数复制到对应的输出中。
    经过对这几个主要方法的理解,我们就清楚了KVStore的主要工作方式,也就对它对内存和显存的占用有了一个清晰的了解。具体的实现细节还是要参考源码去了解。

3.1.4 KVStoreDist

这篇博客的重点还是去试图了解分布式下的KVStore,当我们使用dist-*去create KVStore的时候,就会使用到类KVStoreDistKVStoreDist分两个主要部分,一个是worker,一个是server
如果该节点是worker,首先会创建一个ps_worker_ = new ps::KVWorker(0, new_customer_id);这个ps::KVWorker将在pslite部分具体解析,它是主要的完成pushpull操作的部分。
server的启动:在我们通过import mxnet的时候,会导入kvstore_server,而导入该文件会允许语句_init_kvstore_server_module(),阅读该函数源码不难发现,它会判断当前节点是否是server节点,如果是就会调用server.run(),然后调用c++代码的MXKVStoreRunServer,也就是类KVStoreDistRunServer方法,该方法会创建server_ = new KVStoreDistServer();

  • set_updater:updater的设置是通过python端的函数定义来完成的,它通过ctype转换成为了c端的函数,并且通过pickle序列化为字符串传递给server。
    当然,我们的主要注意力还是放在pushpull的实现上。
  • Push_:push操作首先会通过comm_进行Reduce操作,并将结果存储在comm_buf_[key]中,完成了本地的Reduce后,调用EncodeDefaultKey函数将存储为key : intval : NDArray形式的KVPair,转化为PSKV形式,该形式用于Push操作。之后会通过PushDefault方法完成操作,该方法定义了函数push_to_servers,将comm_buf_[key]作为输入,通过Engine::Get()->PushAsync方法完成push操作的异步执行(只是将任务发给Engine,由Engine完成调度)。Engine会在适当的时机执行push_to_servers,该函数调用ps_worker_ZPush方法来实现分布式的push
  • PullImpl:pull操作由该函数来完成,该函数会根据keysserver端的结果获取到对应的NDArray中。中间结果会保存在comm_buf_[key]中,这里由于之前push将该变量作为了输入,Engine在调度执行时会考虑到这点,保证所有对comm_buf_[key]的操作都在对它的读入完成之后,也就是push完成之后(push将它作为了输入)。类似于Push_操作,Pull操作定义了函数pull_from_servers作为异步执行的函数,调用PushAsync发送给Engine。pull_from_servers函数调用了ps_worker_ZPull方法来完成分布式的pull操作。

这里的分析只是简单的流程的总结,更多实现的细节可以通过阅读源码来了解。

3.1.5 KVStoreDistServer

如果当前节点是server,那么就会建立一个KVStoreDistServer对象,由该对象完成对workerpush,pull请求的处理。其中最重要的方法是DataHandleEx,它根据RequestType来调用相应的函数完成对数据的处理。
KVStoreDistServer的构造函数中,会执行ps_server_ = new ps::KVServer(0);它建立了一个ps::KVServer对象,该对象调用ps_server_->set_request_handle(std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));DataHandleEx绑定在自己的request_handle_上。

  • Run:前面提到过,如果该节点是server会调用RunServer方法,该方法就会调用if (server_) server_->Run();阅读KVStoreDistServer的源码发现,Run仅仅只有一行exec_.Start();。这一行会调用Executor exec_;Start方法,源码如下
void Start() {
    std::unique_lock lk(mu_);
    while (true) {
      cond_.wait(lk, [this]{return !queue_.empty();}); // queue_为空,则等待被唤醒
      Block blk = std::move(queue_.front()); // 取出queue头元素
      queue_.pop();
      lk.unlock(); // 释放锁,给其他线程操作queue

      if (blk.f) { // 如果blk定义了一个function,则允许他
        blk.f();
        blk.p->set_value(); // 返回function的结果
      } else {
        blk.p->set_value(); break;
      }
      lk.lock(); // 获取锁,执行下一个循环
    }

调用ExecutorExec方法,会在queue中添加一个执行函数的block,代码如下

void Exec(const Func& func) {
    Block blk(func); // 建立block
    auto fut = blk.p->get_future();
    {
      std::lock_guard lk(mu_);
      queue_.push(std::move(blk));
      cond_.notify_one(); // 通知别的线程运行
    }
    fut.wait();
  }

有了上面的知识,我们来看一下具体怎么处理数据。

  • DataHandleDefault:该方法是默认的数据处理的方法,由于DataHandleEx被绑定为了数据的处理函数,当RequestTypekDefaultPushPull,就会调用该函数。它会根据传入的信息,提取对应的key,将对应的数据存储在store_[key]。如果从worker来的request类型是push,就会分两种情况运行。一种是初始化的时候,由于初始化同样通过调用push来完成,因此初始化的push只会将store_[key]设置为对应的值。另一种是初始化后,每一次的push都会进行相应的操作。这里每一次从任何一个worker来的某一个keypush操作,都会存储在updates.merged中,并且除了第一次的push,之后的push会进行updates.merged += updates.temp_array;也就是和之前的push相加。并且ApplyUpdates只会在push数达到worker的个数的时候,才会真正地进行。也只有在ApplyUpdates真正执行的时候才会将回复返回给worker。这样,就实现了同步。

对于server的讲解,这里也只是简单地描述它的同步和执行的简单机制,具体更多的实现细节,可以参考源码来了解。

3.2 pslite

通过前面的了解,我们知道了worker会使用ps_worker_ZPush方法来完成push操作,使用ZPull方法来完成pull操作。类似地,server会使用ps_server_request_handle_来进行数据处理的传递,使用SimpleApprequest_handle_来完成Command处理的传递。这一部分,我们就来了解它们的实现。
KVWorkerKVServer都定义在文件kv_app.h中,它们都继承自SimpleApp

3.2.1 kv_app

kv_app是MxNet主要应用的部分。ps-lite实现了两个app,一个是simple_app,一个是kv_app

KVWorker

当数据从MXNet端的push函数传递到parameter server端时,调用了如下方法:

int ZPush(const SArray& keys,
            const SArray& vals,
            const SArray& lens = {},
            int cmd = 0,
            const Callback& cb = nullptr) {
    int ts = obj_->NewRequest(kServerGroup);
    AddCallback(ts, cb);
    KVPairs kvs;
    kvs.keys = keys;
    kvs.vals = vals;
    kvs.lens = lens;
    Send(ts, true, cmd, kvs);
    return ts;
  }

该方法将kServerGroup作为数据传输对象,建立了KVPairs,通过Send方法,将数据发送给server。Send方法完成了数据从KVParisMessage的转换,然后调用Postoffice::Get()->van()->Send(msg);来执行数据的发送。

相应地,在执行pull操作的时候,调用了Pull_方法,该方法首先定义了一个回调函数,该函数在完成pull操作后执行,具体来说就是当发出的请求都得到了回应后,会在Process方法中执行下列函数:

// finished, run callbacks
  if (obj_->NumResponse(ts) == Postoffice::Get()->num_servers() - 1)  {
    RunCallback(ts);
  }

KVServer

前面说到过,KVServer会使用request_handle_来调用KVStore的数据处理函数。KVServer会在方法KVServer::Process中调用request_handle_,在这之前它会将得到的Message转换为KVMetaKVPairs。这样就完成了数据从接收到,再到传递给MXNet端的数据处理函数的过程。

由于时间有限,内容较多,就不一一介绍函数。

3.2.2 postoffice.cc

Postoffice是一个类似于全局管理者的角色,它完成了环境初始化等必要工作

3.2.3 van

从前面的介绍我们看到,所有的数据在发送的最后,调用的都是van的send方法。van的具体实现类是ZMQVan。由于本人对于zmq也只是个初学者,这里有兴趣的同学可以去详细了解它的实现以及性能。

3.2.4 meta.proto

zmq在进行数据传输的时候,会建立socket,并且将字符串传递给对应的对象。在代码中,使用了protobuf来进行数据到字符串的转换工作。

3.2.4 SArray.h

SArray全名Shared array,它完成了在进行数据赋值过程中的零拷贝,及时是不同类型间数据的赋值,仅仅是将数据指向的指针进行赋值,同时将类型进行保存而已。

3.2.5 message.h

后记

今天已经很晚了,只能在pslite部分草草收尾,希望下次进行总结的时候能够做的更好。
总体来说,MXNet对于我这样一个初学者来说有很多可以学习的地方,并且它异步的实现和parameter server的设计,都是非常值得学习的内容。

你可能感兴趣的:(MxNet源码解析(1) KVStore,pslite源码解析)