MXNet: Barrier

1. KVStore里的Barrier

在mxnet的分布式训练里,主要模式就是参数服务器。每个worker或者agent就是一台machine,server用于参数的更新。那么,当我们期望在不同的worker之间进行同步的时候,就会需要到barrier这个方法。
当代码运行在worker的时候,我们可以通过调用kv._barrier()来进行同步。它的作用就是,会阻塞代码运行,直到每个worker都运行了kv._barrier()。然后接着运行。这样就实现了同步。
那么它是怎么做到的呢?

通过源码,我们不难发现,python端的接口调用了c++端的方法:

void Barrier() override {
    ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
}

这个全局的PostofficeBarrier方法的部分源码如下:

void Postoffice::Barrier(int customer_id, int node_group) {
  // 省略部分代码
  // 省略部分代码


  std::unique_lock ulk(barrier_mu_);
  barrier_done_[0][customer_id] = false;
  Message req;
  req.meta.recver = kScheduler;
  req.meta.request = true;
  req.meta.control.cmd = Control::BARRIER;
  req.meta.app_id = 0;
  req.meta.customer_id = customer_id;
  req.meta.control.barrier_group = node_group;
  req.meta.timestamp = van_->GetTimestamp();
  CHECK_GT(van_->Send(req), 0);
  barrier_cond_.wait(ulk, [this, customer_id] {
      return barrier_done_[0][customer_id];
    });
}

可以看到该方法会首先对barrier_mu_上锁,之后将对应的barrier_done_设置为false。然后将这次的barrier信息发送给scheduler。告诉scheduler需要进行一次barrier。然后就阻塞等待barrier_done_被设置为true,代表完成了barrier,也就是其他的worker也都进行了barrier。

那么问题就变成了,每个worker都是怎么直到其他worker也进行了barrier的?

首先我们要知道,在参数服务器也就是PS中,每个进程都会建立kvstore。如果是worker,会在构造函数中运行如下代码:

if (IsWorkerNode()) {
      int new_customer_id = GetNewCustomerId();
      ps_worker_ = new ps::KVWorker(0, new_customer_id);
      ps::StartAsync(new_customer_id, "mxnet\0");
      if (!ps::Postoffice::Get()->is_recovery()) {
        ps::Postoffice::Get()->Barrier(
          new_customer_id,
          ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
      }
    }

其中ps::StartAsync如下:

inline void StartAsync(int customer_id, const char* argv0 = nullptr) {
  Postoffice::Get()->Start(customer_id, argv0, false);
}

也就是说,worker在建立起ps_worker_后,开始运行postoffice,而postoffice的Start会进行一系列的操作,并调用van_->Start,接着vanStart会进行一系列的初始化后,开启接受消息的线程,也就是

receiver_thread_ = std::unique_ptr(
            new std::thread(&Van::Receiving, this));

receiving函数会使用ProcessBarrierCommand处理barrier信号,该函数会++barrier_count_[group],也就是将对应group的barrier次数进行统计。当barrier_count_[group]等于这个group的个数的时候。它会发送类似于ACK的返回信息。

然后worker会调用Manage方法来处理该message。Manage发现是barrier的返回信息,将barrier_done_设置为true,然后将等待的线程唤醒。也就是python端调用barrier后被阻塞的地方。

至此,就完成了一次worker之间的barrier。

你可能感兴趣的:(MXNet: Barrier)