Paramter Server

Paramter Server

​ Author:lyp@ Date:2017/01/26

在深度神经网络计算框架中,参数服务器是一个非常重要的基础概念,而其不同的实现对计算效果和计算能力都有直接的影响。

请你自学参数服务器的概念,并给出一个综述,介绍什么是参数服务器,它对机器学习的作用是什么,一般实现有哪些方案,各自又有哪些优缺点。

背景

1. 问题的提出?

​ 在大规模数据上跑机器学习任务是过去十多年内系统架构师面临的主要挑战之一,许多模型和抽象先后用于这一任务。

​ 现在的大数据机器学习系统,通常数据在1TB到1PB之间,参数范围在10^9 和10^12左右。而往往这些模型的参数需要被所有的worker节点频繁的访问,这就会带来很多问题和挑战:

  1. 访问这些巨量的参数,需要大量的网络带宽支持;
  2. 很多机器学习算法都是连续型的,只有上一次迭代完成(各个worker都完成)之后,才能进行下一次迭代,这就导致了如果机器之间性能差距大(木桶理论),就会造成性能的极大损失;
  3. 在分布式中,容错能力是非常重要的。很多情况下,算法都是部署到云环境中的(这种环境下,机器是不可靠的,并且job也是有可能被抢占的);

2. 业内如何解决?

​ 如何解决这些问题呢?对于机器学习分布式优化,有很多大公司在做了,包括:Amazon,Baidu,Facebook,Google,Microsoft 和 Yahoo。也有一些开源的项目,比如:YahooLDA 和 Petuum 和Graphlab。

​ 从最开始的MPI,到Hadoop,Spark 以及Paramter Server。都曾广泛应用于机器学习处理任务。总结一下:

  • ==MPI Gradient Aggregation==:批任务求解器的速度不高,无法支持大规模数据集。
  • ==MapReduce==:解决了MPI无法支撑大数据的问题。但无法改进批处理求解器的训练性能,并且还引入了新的问题,包括迭代式计算的低效,节点之间通信低效。
  • ==GraphLab==:基于图的抽象。用图来做抽象可以解决许多机器学习问题,但仍然有许多问题无法很好高效求解,比如深度学习中的多层结构。
  • ==Parameter Server==:跟基于图的方法主要区别在于把模型参数存储和更新上升为主要组件,并且采用了异步机制提升处理能力。

Paramter Server发展历程

​ 参数服务器也经历了多次发展。

  • 第一代参数服务器

    ​ 参数服务器的概念最早来自Alex Smola于2010年提出的并行LDA的框架[4]。它通过采用一个分布式的Memcached作为存放参数的存储,这样就提供了有效的机制用于在分布式系统不同的Worker节点之间同步模型参数,而每个Worker只需要保存它计算时所依赖的一小部分参数即可。当然,这里存放参数的存储跟做OLTP应用中的Key-Value抽象有所不同,因为以Key-Value为单元进行频繁的参数数据交互会导致过高的通信开销,因此参数服务器通常采用数学封装来进行参数同步,比如向量,张量,矩阵的行列等。

  • 第二代参数服务器

    ​ 对第二代分布式架构做改进的初步尝试是Petuun,他使用一个有限制的延迟模型,同时在工作线程模型上添加更多限制。

  • 第三代参数服务器

    ​ 来自Alex Smola的高徒——李沐设计的参数服务器。ps-lite应当属于第三代参数服务器,提供了更加通用的设计。==以下也主要介绍他在论文中提供的ps架构==。

Paramter Server架构设计

1. Paramter Server 整体架构

PS架构主要包括两大部分。那就是一个参数服务器组server group 和多个工作组。在parameter server中,每个 server 实际上都只负责分到的部分参数(servers共同维持一个全局的共享参数),而每个 work 也只分到部分数据和处理任务;

Paramter Server_第1张图片
image

一些概念解释:

  • server 节点可以跟其他 server 节点通信,每个server负责自己分到的参数,server group 共同维持所有参数的更新。
  • server manager node 负责维护一些元数据的一致性,比如各个节点的状态,参数的分配情况等;
  • worker 节点之间没有通信,只跟自己对应的server进行通信。
  • 每个worker group有一个task scheduler,负责向worker分配任务,并且监控worker的运行情况。当有新的worker加入或者退出,task scheduler 负责重新分配任务。
  • training data 被split多个部分,一个worker在本地将一部分训练数据存储在本地统计数据中。

2. Paramter Server通信设计

  • ==(key,value)==

    ​ parameter server 中,参数都是可以被表示成(key, value)的集合,比如一个最小化损失函数的问题,key就是feature ID,而value就是它的权值。对于稀疏参数,不存在的key,就可以认为是0.


    Paramter Server_第2张图片
    image
  • ==(key,value)Vectors==

    ​ 如果每一个参数都设一个key,那么会使得通信变得非常频繁低效,为了抹平这个问题,赋予每个key所对应的value向量概念或者矩阵概念。做这样的操作的前提是假设参数是有顺序的。

    ​ 这样做有两点好处,降低网络通信和使得向量层面的操作变得可行,从而很多线性库的优化特性可以利用的上,比如BLAS、LAPACK、ATLAS等。缺点是在对于稀疏模型来说,总会在向量或者矩阵里会有参数为0,这在单个参数状态下是不用存的,所以,造成了数据的冗余。

  • ==range push and pull==

    ​ workers 跟 servers 之间通过 pushpull 来通信。worker 通过 push 将计算好的梯度发送到server,然后通过 pull 从server更新参数。为了提高计算性能和带宽效率,parameter server 允许用户使用 Range PushRange Pull 操作;

    ​ 假设 R 是需要push或pull的 key 的range,那么可以进行如下操作。就是发送和接送特定Range中的w。

    w.push(R, dest)
    w.pull(R, dest)
    
Paramter Server_第3张图片
image

- ==Asynchronous Tasks and Dependency & Flexible Consistency==

​ 体会一下Asynchronous Task 跟 Synchronous Task 的区别。

​ 如果 iter12 需要在 iter11 computation,push 跟 pull 都完成后才能开始,那么就是Synchronous,反之就是Asynchronous.如iter 11 在 iter10计算完成后就开始执行。

![image](http://upload-images.jianshu.io/upload_images/2472711-92152ae529e7940b.jpg?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)


​ 参数服务器和工作节点之间的通信都属于远程调用,那么,远程调用是比较耗时的行为,如果每次都保持同步的话,那么训练相对于单节点来说是减慢了许多的,因为远程调用的耗时。因而,PS框架让远程调用成为一部调用,比如参数的push和pull发出之后,立即使用当前值开始进行下一步的梯度计算,如上图,迭代11发出push和pull的请求后,立马开始进行梯度计算,而此时,使用的还是迭代10的值。

​ **Asynchronous Task**:能够提高系统的效率(因为节省了很多等待的过程),但是,它的缺点就是容易降低算法的收敛速率;

​ 所以,系统性能跟算法收敛速率之间是存在一个trade-off的,你需要同时考虑:

```xml
算法对于参数非一致性的敏感度;
训练数据特征之间的关联度;
硬盘的存储容量;

​ 考虑到用户使用的时候会有不同的情况,parameter server 为用户提供了多种任务依赖方式:


Paramter Server_第4张图片
image
  • Sequential: 这里其实是 synchronous task,任务之间是有顺序的,只有上一个任务完成,才能开始下一个任务;

  • Eventual: 跟 sequential 相反,所有任务之间没有顺序,各自独立完成自己的任务,

  • Bounded Delay: 这是sequential 跟 eventual 之间的trade-off,可以设置一个 τ 作为最大的延时时间。也就是说,只有 >τ 之前的任务都被完成了,才能开始一个新的任务;极端的情况:

    • τ=0,情况就是 Sequential;
    • τ=∞,情况就是 Eventual;
  • ==User-defined Filters==

    ​ 作为上述特点的补充,PS还有这样一个小feature,即过滤,在工作节点这一端对梯度进行过滤,如果梯度并不是那么影响重大,就不用占用网络去更新。

Paramter Server架构实现

1. Vector Clock

​ 为参数服务器中的每个参数添加一个时间戳,来跟踪参数的更新和防止重复发送数据。基于此,通信中的梯度更新数据中也应该有时间戳,防止重复更新。

​ 如果每个参数都有一个时间戳,那么参数众多,时间戳也众多。好在,parameter server 在push跟pull的时候,都是rang-based,这就带来了一个好处:这个range里面的参数共享的是同一个时间戳,这显然可以大大降低了空间复杂度。

2. Messages

​ Message是节点间交互的主要格式。一条 message 包括:时间戳,len(range)对k-v.

​ $[vc(R),(k1,v1),...,(kp,vp)]kj∈Randj∈{1,...p}$

​ 这是parameter server 中最基本的通信格式,不仅仅是共享的参数才有,task 的message也是这样的格式,只要把这里的(key, value) 改成 (task ID, 参数/返回值)。

​ Messages may carry a subset of all available keys within range R. The missing keys are assigned the same timestamp without changing their values.

​ 由于机器学习问题通常都需要很高的网络带宽,因此信息的压缩是必须的。

  • key的压缩: 因为训练数据通常在分配之后都不会发生改变,因此worker没有必要每次都发送相同的key,只需要接收方在第一次接收的时候缓存起来就行了。第二次,worker不再需要同时发送key和value,只需要发送value 和 key list的hash就行。这样瞬间减少了一般的通信量。
  • value的压缩: 假设参数时稀疏的,那么就会有大量的0存在。因此,为了进一步压缩,我们只需要发送非0值。parameter server使用 Snappy 快速压缩库来压缩数据、高效去除0值。
  • key 的压缩和 value 的压缩可以同时进行。

3. Replication and Consistency

​ parameter server 在数据一致性上,使用的是传统的一致性哈希算法,参数key与server node id被插入到一个hash ring中。具体实现可以参考另一篇blog一致性hash算法详解。动态增加和移除节点的同时还能保证系统存储与key分配的性能效率.

Paramter Server_第5张图片
image

​ 两种方式保证slave跟master之间的数据一致性:

  1. 默认的复制方式: Chain replication (强一致性, 可靠):

    Paramter Server_第6张图片
    image

a. 更新:只能发生在数据头节点,然后更新逐步后移,直到更新到达尾节点,并由尾节点向客户确认更新成功;
b. 查询:为保证强一致性,客户查询只能在尾节点进行;

  1. Replication after Aggregation
    Paramter Server_第7张图片
    image

两个worker 节点分别向server传送x和y。server 首先通过一定方式(如:f(x+y) )进行aggregate,然后再进行复制操作;

当有n个worker的时候,复制只需要k/n的带宽。通常来说,k(复制次数)是一个很小的常数,而n的值大概是几百到几千;

4. Server Management

由于key的range特性,当参数服务器集群中增加一个节点时,步骤如下:

  • server manager节点给新节点分配一个key range,这可能会导致其他节点上的key range切分
  • 新节点从其他节点上将属于它的key range数据取过来,然后也将slave信息取过来
  • server manager广播节点变动,其他节点得知消息后将不属于自己key range的数据删掉

​ 在第二步,从其他节点上取数据的时候,其他节点上的操作也分为两步,第一是拷贝数据,这可能也会导致key range的切分。第二是不再接受和这些数据有关的消息,而是进行转发,转发到新节点。

​ 在第三步,收到广播信息后,节点会删除对应区间的数据,然后,扫描所有的和R有关发送出去的还没收到回复的消息,当这些消息回复时,转发到新节点。

​ 节点的离开与节点的加入类似。

5. Worker Management

添加工作节点比添加服务器节点要简单一些,步骤如下:

  • task scheduler给新节点分配一些数据
  • 节点从网络文件系统中载入数据,然后从服务器端拉取参数
  • task scheduler广播变化,其他节点free掉一些训练数据

​ 当一个节点离开的时候,task scheduler可能会寻找一个替代,但恢复节点是十分耗时的工作,同时,损失一些数据对最后的结果可能影响并不是很大。所以,系统会让用户进行选择,是恢复节点还是不做处理。这种机制甚至可以允许用户删掉跑的最慢的节点来提升速度。

PS-lite 实现

PS-Lite是PS架构的一个轻量级的实现。它提供了push,pull,wait等APIs。整个项目代码量不多。

A light and efficient implementation of the parameter server framework. It provides clean yet powerful APIs. For example, a worker node can communicate with the server nodes by

  • Push(keys, values): push a list of (key, value) pairs to the server nodes
  • Pull(keys): pull the values from servers for a list of keys
  • Wait: wait untill a push or pull finished.

A simple example:

  std::vector key = {1, 3, 5};
  std::vector val = {1, 1, 1};
  std::vector recv_val;
  ps::KVWorker w;
  w.Wait(w.Push(key, val));
  w.Wait(w.Pull(key, &recv_val));

总体概览

整个项目的类图如下:

Paramter Server_第8张图片
image
  • Postoffice是全局管理类,单例模式创建。管理当前节点角色、其他节点的连接、心跳信息、配置信息。
  • Van是负责通信的类,是Postoffice的成员。Van中std::unordered_map senders_保存了node_id到连接的映射。Van只是定义了接口,具体实现是依赖ZMQ实现的ZMQVan
  • Customer用来通信,跟踪request和response。每一个连接对应一个Customer实例,连接对方的id和Customer实例的id相同。
  • SimpleApp是一个基类;提供了发送接收int型的head和string型的body消息,以及注册消息处理函数。它有2个派生类。
  • KVServer是SimpleApp的派生类,用来保存key-values数据。
  • KVWorker是SimpleApp的派生类,用来想Server Push/Pull key-value数据。
  • KVPairs封装了Key-Value结构,还包含了一个长度选项。
  • SArray是Shared array,像智能指针一样共享数据,接口类似vector。
  • Node封装了节点的信息,例如角色、ip、端口、是否是恢复节点。
  • Control封装了控制信息,例如命令类型、目的节点、barrier_group的id、签名。
  • Meta封装了元数据,发送者、接受者、时间戳、请求还是相应等。
  • Message是要发送的信息,除了元数据外,还包括发送的数据。

节点角色ID

Paramter Server_第9张图片
image

一共三种类型,从上图可以看出Scheduler节点只有一个,多个Worker和多个Server可以组成一个Group,因此有WorkerGroup和ServerGroup;还有Worker节点和Server节点。每个节点以及每一个Group都有唯一确定的ID。
Scheduler、ServerGroup、WorkerGroup节点ID确定如下:

/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
 * \brief the server node group ID
 *
 * group id can be combined:
 * - kServerGroup + kScheduler means all server nodes and the scheuduler
 * - kServerGroup + kWorkerGroup means all server and worker nodes
 */
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;

上述定义在base.h中。

1、2、4的二进制表示分别为:001、010、001。这样可以做Group之间的合并,例如要和ServerGroup和WorkerGroup发信息,只需要destination node id设为2+4=6。
1-7用来表示节点的组合。单个节点的ID从8开始。单个Server和单个Worker节点从自己的rank(0、1、2……)转换到其ID:

/**
   * \brief convert from a worker rank into a node id
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
  }
  /**
   * \brief convert from a server rank into a node id
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
  }

ID到其rank转换:

  static inline int IDtoRank(int id) {
    return std::max((id - 8) / 2, 0);
  }

Postofficetd::unordered_map> node_ids_==保存了Node/NodeGroup与连接节点集合的对应关系==。

消息封装

Paramter Server_第10张图片
zz
  • 首先使用了自定义的SArray,Smart Array。共享数据,减少数据拷贝,且提供了类似vector的接口。==备注:没有仔细看实现,在这里先把他理解成一个数组。==

  • 元数据Meta使用了Protobuf,进行了数据压缩.具体使用可以参考blog14,链接见参考文献。具体定义见代码。

  • 消息分层比较清晰。Node包含节点的角色、id、ip、端口信息;Control包含了命令信息、签名等;Meta是元数据,包含时间戳、发送者、接受者、控制信息等;Message才是发送的信息,包含元数据和发送的数据。他们之间的构造见上图。

  • 参数有key-value组成,对应于KVPairs。定义如下

    struct KVPairs {
      // /** \brief empty constructor */
      // KVPairs() {}
      /** \brief the list of keys */
      SArray keys;
      /** \brief the according values */
      SArray vals;
      /** \brief the according value lengths (could be empty) */
      SArray lens;
    };
    

通信机制

Scheduler节点管理所有节点的地址。每个节点要知道Scheduler节点的IP、port;启动时绑定一个本地端口,并向Scheduler节点报告。==Scheduler节点在每个几点启动后,给节点分配ID,把节点信息通知出去(例如Worker节点要知道Server节点IP和端口,Server节点要知道Worker节点的IP和端口)==。节点在建立连接后,才会正式启动。

测试链接的过程:

  • test.connection.cc:执行ps:start函数

    int main(int argc, char *argv[]) {
      ps::Start(0);
      // do nothing
      ps::Finalize(0, true);
      return 0;
    }
    
    • Start,获取一个postoffice的单利对象,并调用start方法。

      inline void Start(int customer_id, const char* argv0 = nullptr) {
        Postoffice::Get()->Start(customer_id, argv0, true);
      }
      
      • postoffice的start函数,是主要内容。分三步,1.初始化node,2.start van 3.如果设置barrier,start barrier。

        void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
          ....
            //1. init node info. 多少个worker。
            for (int i = 0; i < num_workers_; ++i) {
              int id = WorkerRankToID(i); //i*2+9 获得workid
              for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                            kWorkerGroup + kScheduler,
                            kWorkerGroup + kServerGroup + kScheduler}) {
                node_ids_[g].push_back(id);
              }
            }
            for (int i = 0; i < num_servers_; ++i) {
              int id = ServerRankToID(i);
              for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
                            kServerGroup + kScheduler,
                            kWorkerGroup + kServerGroup + kScheduler}) {
                node_ids_[g].push_back(id);
              }
            }
            for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                          kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
              node_ids_[g].push_back(kScheduler);
            }
            init_stage_++;
          }
           .....
          //2. start van
          //已经初始化完node,这边启动通信。
          van_->Start(customer_id);
           ......
          //3. do a barrier here
          if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
        }
        
        • 此刻,我们查看van_->Start(customer_id);方法。int init_stage = 0;这个是van初始化最开始的值。

          void Van::Start(int customer_id) {
            // get scheduler info
            start_mu_.lock();
            if (init_stage == 0) {
              scheduler_.hostname = std::string(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI")));
              scheduler_.port = atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT")));
              scheduler_.role = Node::SCHEDULER;
              scheduler_.id = kScheduler;
              is_scheduler_ = Postoffice::Get()->is_scheduler();
          
              //1. get my node info 获取节点信息,主要是ip port role 等信息
              if (is_scheduler_) {
                my_node_ = scheduler_;
              } else {
                auto role = is_scheduler_ ? Node::SCHEDULER :
                            (Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER);
                const char *nhost = Environment::Get()->find("DMLC_NODE_HOST");
                std::string ip;
                if (nhost) ip = std::string(nhost);
                if (ip.empty()) {
                  const char *itf = Environment::Get()->find("DMLC_INTERFACE");
                  std::string interface;
                  if (itf) interface = std::string(itf);
                  if (interface.size()) {
                    GetIP(interface, &ip);
                  } else {
                    GetAvailableInterfaceAndIP(&interface, &ip);
                  }
                  CHECK(!interface.empty()) << "failed to get the interface";
                }
                int port = GetAvailablePort();
                const char *pstr = Environment::Get()->find("PORT");
                if (pstr) port = atoi(pstr);
                CHECK(!ip.empty()) << "failed to get ip";
                CHECK(port) << "failed to get a port";
                my_node_.hostname = ip;
                my_node_.role = role;
                my_node_.port = port;
                // cannot determine my id now, the scheduler will assign it later
                // set it explicitly to make re-register within a same process possible
                my_node_.id = Node::kEmpty;
                my_node_.customer_id = customer_id;
              }
          
              //2. bind. 绑定端口
              my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40);
              PS_VLOG(1) << "Bind to " << my_node_.DebugString();
              CHECK_NE(my_node_.port, -1) << "bind failed";
          
              //3. connect to the scheduler 建立连接。具体实现在 zmq中。
              Connect(scheduler_);
          
              //4. for debug use
              if (Environment::Get()->find("PS_DROP_MSG")) {
                drop_rate_ = atoi(Environment::Get()->find("PS_DROP_MSG"));
              }
              //5. start receiver 开一个新的线程。来接受信息。
              receiver_thread_ = std::unique_ptr(
                      new std::thread(&Van::Receiving, this));
              init_stage++;
            }
            start_mu_.unlock();
          
            if (!is_scheduler_) {
              // let the scheduler know myself 如果不是scheduler,发送message。告知自己。
              Message msg;
              Node customer_specific_node = my_node_;
              customer_specific_node.customer_id = customer_id;
              msg.meta.recver = kScheduler;
              msg.meta.control.cmd = Control::ADD_NODE;
              msg.meta.control.node.push_back(customer_specific_node);
              msg.meta.timestamp = timestamp_++;
              Send(msg);
            }
            // wait until ready
            while (!ready_.load()) {
              std::this_thread::sleep_for(std::chrono::milliseconds(100));
            }
          
            start_mu_.lock();
            if (init_stage == 1) {
              // resender
              if (Environment::Get()->find("PS_RESEND") && atoi(Environment::Get()->find("PS_RESEND")) != 0) {
                int timeout = 1000;
                if (Environment::Get()->find("PS_RESEND_TIMEOUT")) {
                  timeout = atoi(Environment::Get()->find("PS_RESEND_TIMEOUT"));
                }
                resender_ = new Resender(timeout, 10, this);
              }
          
              if (!is_scheduler_) {
                // start heartbeat thread 开启一个新的县城 进行心跳检测。
                heartbeat_thread_ = std::unique_ptr(
                        new std::thread(&Van::Heartbeat, this));
              }
              init_stage++;
            }
            start_mu_.unlock();
          }
          
          

至此,通信连接建立完成。

同步策略

异步工作时,Worker计算参数可能要依赖前面Pull是否完成。如果需要等待某一步操作,可以调用SimpleApp::Wait操作。具体实现是调用了Customer::WaitRequest(),它会跟踪request和response数量是否相同,直到相同才会返回;tracker_类型为std::vector>,记录了request和response数量,这个数据结构一直增长,会造成内存一直增长。

消息处理流程

每个节点都监听了本地一个端口;该连接的节点在启动时已经连接。 上述通信机制的时候已经描述过。

回顾一下通信机制中VAN start()方法的内容:

  • 获取scheduler信息;
  • 获取node 信息;绑定端口。
  • 连接scheduler节点。Connect(scheduler__)
  • ==开启一个线程,来接受信息==。receiver_thread_ = std::unique_ptr< std::thread>( new std::thread(&Van::Receiving, this));
  • 如果不是scheduler节点,发送messge,告诉scheduler节点。
  • 开一个线程,来发送心跳。

而针对消息处理流程,主要的逻辑集中在上述标黄那一步开始的。

对于==Server节点==:

  1. Van::Receiving()函数是单独一个线程来接收数据。数据接收后,根据不同命令执行不同动作,例如Control::ADD_NODE就是添加节点。如果需要下一步处理,会将消息传递给Customer::Accept函数。

    void Van::Receiving() {
      Meta nodes;
      Meta recovery_nodes;  // store recovery nodes
      recovery_nodes.control.cmd = Control::ADD_NODE;
    
      while (true) {
        Message msg;
        int recv_bytes = RecvMsg(&msg);
        // For debug, drop received message
        if (ready_.load() && drop_rate_ > 0) {
          unsigned seed = time(NULL) + my_node_.id;
          if (rand_r(&seed) % 100 < drop_rate_) {
            LOG(WARNING) << "Drop message " << msg.DebugString();
            continue;
          }
        }
    
        CHECK_NE(recv_bytes, -1);
        recv_bytes_ += recv_bytes;
        if (Postoffice::Get()->verbose() >= 2) {
          PS_VLOG(2) << msg.DebugString();
        }
        // duplicated message
        if (resender_ && resender_->AddIncomming(msg)) continue;
    
        if (!msg.meta.control.empty()) {
          // control msg
          auto& ctrl = msg.meta.control;
          if (ctrl.cmd == Control::TERMINATE) {
            ProcessTerminateCommand();
            break;
          } else if (ctrl.cmd == Control::ADD_NODE) {
            ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes);
          } else if (ctrl.cmd == Control::BARRIER) {
            ProcessBarrierCommand(&msg);
          } else if (ctrl.cmd == Control::HEARTBEAT) {
            ProcessHearbeat(&msg);
          } else {
            LOG(WARNING) << "Drop unknown typed message " << msg.DebugString();
          }
        } else {
          ProcessDataMsg(&msg);
        }
      }
    }
    

  2. Customer::Accept()函数将消息添加到一个队列recv_queue_Customer::Receiving()是一个线程在运行,从队列取消息处理;处理过程中会使用函数对象recv_handle_处理消息,这个函数对象是SimpleApp::Process函数。

    void Customer::Receiving() {
      while (true) {
        Message recv;
        recv_queue_.WaitAndPop(&recv);
        if (!recv.meta.control.empty() &&
            recv.meta.control.cmd == Control::TERMINATE) {
          break;
        }
        //该线程处理消息
        recv_handle_(recv);
        if (!recv.meta.request) {
          std::lock_guard lk(tracker_mu_);
          tracker_[recv.meta.timestamp].second++;
          tracker_cond_.notify_all();
        }
      }
    }
    
  3. SimpleApp::Process根据是消息类型(请求or响应,调用用户注册的函数来处理消息,request_handle_response_handle_分别处理请求和响应。

对于Worker节点,上面第3点略有不同。因为Worker都是通过PushPull来通信,而且参数都是key-value对。Pull·参数时,通过KVWorker::Process调用回调函数来处理消息。

调试及启动流程

PS Lite通过环境变量和外界交互。

启动流程:
1、首先启动Scheduler节点。这是要固定好Server和Worker数量。
2、启动Worker或Server节点。启动时连接Scheduler节点,绑定本地端口,并向Scheduler节点注册自己信息。
3、Scheduler等待所有Worker节点都注册后,给其分配id,并把节点信息传送出去。此时Scheduler节点已经准备好。
4、Worker或Server接收到Scheduler传送的信息后,建立对应节点的连接。此时Worker或Server已经准备好。

调试时,通过环境变量来控制调试日志。
PS_VERBOSE=1,会打印连接日志。
PS_VERBOSE=2,会打印所有数据通信日志。

源码test中连接事例

#include "ps/ps.h"
using namespace ps;

void StartServer() {
  if (!IsServer()) return;
  auto server = new KVServer(0);
  //设置kv默认处理handle, 可以自定义
  server->set_request_handle(KVServerDefaultHandle());
  RegisterExitCallback([server](){ delete server; });
}

void RunWorker() {
  if (!IsWorker()) return;
  KVWorker kv(0, 0);

  // init
  int num = 10000;
  std::vector keys(num);
  std::vector vals(num);

  int rank = MyRank();
  srand(rank + 7);
  for (int i = 0; i < num; ++i) {
    keys[i] = kMaxKey / num * i + rank;
    vals[i] = (rand() % 1000);
  }

  // push
  int repeat = 50;
  std::vector ts;
  for (int i = 0; i < repeat; ++i) {
    ts.push_back(kv.Push(keys, vals));

    // to avoid too frequency push, which leads huge memory usage
    if (i > 10) kv.Wait(ts[ts.size()-10]);
  }
  for (int t : ts) kv.Wait(t);

  // pull
  std::vector rets;
  kv.Wait(kv.Pull(keys, &rets));

  float res = 0;
  for (int i = 0; i < num; ++i) {
    res += fabs(rets[i] - vals[i] * repeat);
  }
  CHECK_LT(res / repeat, 1e-5);
  LL << "error: " << res / repeat;
}

int main(int argc, char *argv[]) {
  // setup server nodes
  StartServer();
  // start system
  Start(0);
  // run worker nodes
  RunWorker();
  // stop system
  Finalize(0, true);
  return 0;
}

其他

该部分内容暂未完成,进行学习。

PaddlePaddle

//@TODO

Tensorflow

//@TODO

Adam

//@TODO

Adam框架仍然基于Multi-Spert架构,这个架构的大体含义就是将集群分为如下几个部分:

  1. 数据服务类。存储数据,数据备份。向计算节点提供数据。
  2. 训练模型类。训练模型,然后更新参数。
  3. 参数服务器。维护一个共享的模型,计算节点计算完成后,可以向参数服务器发送请求更新参数。

参考文献

参考文献

  1. Scaling Distributed Machine Learning with the Parameter Server
  2. Parameter Server for Distributed Machine Learning
  3. PS-Lite Documents

参考Blog

  1. MPI 在大规模机器学习领域的前景如何
  2. 参数服务器——分布式机器学习的新杀器
  3. Allreduce (or MPI) vs. Parameter server approaches
  4. 横向对比三大分布式机器学习平台:Spark、PMLS、TensorFlow
  5. 机器学习入门:线性回归及梯度下降
  6. 详解并行逻辑回归
  7. 一致性HASH算法详解
  8. 【深度学习&分布式】Parameter Server 详解
  9. parameter_server架构
  10. 【分布式计算】MapReduce的替代者-Parameter Server
  11. Adam:大规模分布式机器学习框架
  12. ParameterServer入门和理解
  13. PS-Lite源码分析
  14. Google Protocol Buffer 的使用和原理
  15. 几种机器学习框架的对比和选择
  16. tensorflow架构
  17. 如何评价百度开源的深度学习框架 PaddlePaddle?

参考项目

  1. https://github.com/dmlc/ps-lite

你可能感兴趣的:(Paramter Server)