PS-Lite源码分析

    • Parameter Server框架
    • PS-Lite实现
      • 总体概览
      • 节点角色ID
      • 消息封装
      • 通信机制
      • 同步策略
      • 消息处理流程
      • 调试及启动流程
      • 一个例子
    • 参考

Parameter Server中文名称叫做参数服务器,是分布式机器学习框架中用来做参数同步的框架。具体介绍可以参考后面链接,这里主要学习一下其实现。

ps-lite是Paramter Server的实现的一个框架,其中参数处理具体相关策略需自己实现。

Parameter Server框架

Parameter Server包含三种角色:Worker、Server、Scheduler。具体关系如下图:

  • Worker节点负责计算参数,并发参数push到Server,同时从Serverpull参数回来。
  • Server节点负责管理Worker节点发送来的参数,并“合并”,之后供各个Worker使用。
  • Scheduler几点负责管理Worker节点和Server节点的状态。

PS-Lite实现

总体概览

简单看一下各个类以及它们之间的关系

  • Postoffice是全局管理类,单例模式创建。管理当前节点角色、其他节点的连接、心跳信息、配置信息。
  • Van是负责通信的类,是Postoffice的成员。Van中std::unordered_map senders_保存了node_id到连接的映射。Van只是定义了接口,具体实现是依赖ZMQ实现的ZMQVan
  • Customer用来通信,跟踪request和response。每一个连接对应一个Customer实例,连接对方的id和Customer实例的id相同。
  • SimpleApp是一个基类;提供了发送接收int型的head和string型的body消息,以及注册消息处理函数。它有2个派生类。
  • KVServer是SimpleApp的派生类,用来保存key-values数据。
  • KVWorker是SimpleApp的派生类,用来想Server Push/Pull key-value数据。
  • KVPairs封装了Key-Value结构,还包含了一个长度选项。
  • SArray是Shared array,像智能指针一样共享数据,接口类似vector。
  • Node封装了节点的信息,例如角色、ip、端口、是否是恢复节点。
  • Control封装了控制信息,例如命令类型、目的节点、barrier_group的id、签名。
  • Meta封装了元数据,发送者、接受者、时间戳、请求还是相应等。
  • Message是要发送的信息,除了元数据外,还包括发送的数据。

节点角色ID

三种节点,从上图可以看出Scheduler节点只有一个,多个Worker和多个Server可以组成一个Group,因此有WorkerGroup和ServerGroup;还有Worker节点和Server节点。每个节点以及每一个Group都有唯一确定的ID。
Scheduler、ServerGroup、WorkerGroup节点ID确定如下:

static const int kScheduler = 1;
static const int kServerGroup = 2;
static const int kWorkerGroup = 4;

1、2、4的二进制表示分别为:001、010、001。这样可以做Group之间的合并,例如要和ServerGroup和WorkerGroup发信息,只需要destination node id设为2+4=6。
1-7用来表示节点的组合。单个节点的ID从8开始。单个Server和单个Worker节点从自己的rank(0、1、2……)转换到其ID:

static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
}
static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
}

ID到其rank转换:

static inline int IDtoRank(int id) {
    return std::max((id - 8) / 2, 0);
}

Postofficetd::unordered_map> node_ids_保存了Node/NodeGroup与连接节点集合的对应关系。

消息封装

  • 首先使用了自定义的SArray,Smart Array。共享数据,减少数据拷贝,且提供了类似vector的接口。
  • 元数据Meta使用了Protobuf,进行了数据压缩。
  • 消息分层比较清晰。Node包含节点的角色、id、ip、端口信息;Control包含了命令信息、签名等;Meta是元数据,包含时间戳、发送者、接受者、控制信息等;Message才是发送的信息,包含元数据和发送的数据。
  • 参数有key-value组成,对应KVPairs

通信机制

Scheduler节点管理所有节点的地址。每个节点要知道Scheduler节点的IP、port;启动时绑定一个本地端口,并向Scheduler节点报告。Scheduler节点在每个几点启动后,给节点分配ID,把节点信息通知出去(例如Worker节点要知道Server节点IP和端口,Server节点要知道Worker节点的IP和端口)。节点在建立连接后,才会正式启动。

同步策略

异步工作时,Worker计算参数可能要依赖前面Pull是否完成。如果需要等待某一步操作,可以调用SimpleApp::Wait操作。具体实现是调用了Customer::WaitRequest(),它会跟踪request和response数量是否相同,直到相同才会返回;tracker_类型为std::vector>,记录了request和response数量,这个数据结构一直增长,会造成内存一直增长。

消息处理流程

每个节点都监听了本地一个端口;该连接的节点在启动时已经连接。
对于Server节点:
1、Van::Receiving()函数是单独一个线程来接收数据。数据接收后,根据不同命令执行不同动作,例如Control::ADD_NODE就是添加节点。如果需要下一步处理,会将消息传递给Customer::Accept函数。
2、Customer::Accept()函数将消息添加到一个队列recv_queue_Customer::Receiving()是一个线程在运行,从队列取消息处理;处理过程中会使用函数对象recv_handle_处理消息,这个函数对象是SimpleApp::Process函数。
3、SimpleApp::Process根据是消息类型(请求or响应,调用用户注册的函数来处理消息,request_handle_response_handle_分别处理请求和响应。

对于Worker节点,上面第3点略有不同。因为Worker都是通过PushPull来通信,而且参数都是key-value对。Pull·参数时,通过KVWorker::Process调用回调函数来处理消息。

调试及启动流程

PS Lite通过环境变量和外界交互。

启动流程:
1、首先启动Scheduler节点。这是要固定好Server和Worker数量。
2、启动Worker或Server节点。启动时连接Scheduler节点,绑定本地端口,并向Scheduler节点注册自己信息。
3、Scheduler等待所有Worker节点都注册后,给其分配id,并把节点信息传送出去。此时Scheduler节点已经准备好。
4、Worker或Server接收到Scheduler传送的信息后,建立对应节点的连接。此时Worker或Server已经准备好。

调试时,通过环境变量来控制调试日志。
PS_VERBOSE=1,会打印连接日志。
PS_VERBOSE=2,会打印所有数据通信日志。

一个例子

参考源码给出的例子。KVPair中一个key对应多个value,具体数量在lens中记录。Server使用key-vector映射存储数据。Server收到的Push数据,只是将对应key的值相加。最后Worker Pull的数据,按照key打印。

#include 
#include "ps/ps.h"
using namespace std;
using namespace ps;
template <class Val>
class MyKVServerHandle {
public:
    void operator() (const KVMeta& req_meta, const KVPairs& req_data, KVServer* server) {
        size_t n = req_data.keys.size();
        KVPairs res;
        if (req_meta.push) { // push
            CHECK_EQ(n, req_data.lens.size());
        } else {            // pull
            res.keys = req_data.keys;
            res.lens.resize(res.keys.size());
        }

        size_t cur_idx = 0;
        for (size_t i = 0;i < n; ++i) {
            Key key = req_data.keys[i];
            if(req_meta.push){ //push
                int len = req_data.lens[i];
                if(store.count(key) == 0){//第一次push,开辟空间
                    store[key] = vector(len, 0);
                }

                for(int idx = 0; idx < len; ++idx){
                    store[key][idx] += req_data.vals[cur_idx++];
                }
            }
            else{ // pull
                res.lens[i] = store[key].size();
                for(int idx = 0; idx < res.lens[i]; ++idx){
                    res.vals.push_back(store[key][idx]);
                }
            }
        }
        server->Response(req_meta, res);
    }

private:
    std::unordered_mapvector> store;
};
void StartServer() {
    if (!IsServer()) return;
    cout << "num of workers[" << NumWorkers() << "]" << endl;
    cout << "num of servers[" << NumServers() << "]" << endl;
    auto server = new KVServer<float>(0);
    server->set_request_handle(MyKVServerHandle<float>());
    RegisterExitCallback([server](){ delete server; });
}
void RunWorker() {
    if (!IsWorker()) return;
    cout << "start Worker rank = " << MyRank() << endl;
    KVWorker<float> kv(0);
    // init
    int key_num = 10;
    int val_num = 0;
    vector keys(key_num);
    vector<int> len(key_num);
    for(int i = 0; i < key_num; ++i){
        keys[i] = i;
        len[i] = i + 1;
        val_num += len[i];
    }

    vector<float> vals(val_num);
    for (int i = 0;i < val_num; ++i) {
        vals[i] = i / 10;
    }
    // push
    int repeat = 10;
    vector<int> ts;
    for (int i = 0;i < repeat; ++i) {
        ts.push_back(kv.Push(keys, vals, len));  
    }
    for (int t : ts) kv.Wait(t);
    // pull
    std::vector<float> ret_val;
    std::vector<int> ret_len;
    kv.Wait(kv.Pull(keys, &ret_val, &ret_len));
    CHECK_EQ(keys.size(), ret_len.size());

    size_t cur_idx = 0;
    for (size_t i = 0;i < keys.size(); ++i) {
        std::cout<" key ["<"] vals [";
        for(int idx = 0; idx < ret_len[i]; ++idx){
            std::cout<<" "<std::cout<<"]"<<std::endl;
    }
    cout << endl;
}
int main(int argc, char* argv[]) {
    StartServer();
    Start(); 
    RunWorker();
    Finalize(); 
    return 0;
}

一个Server,两个Worker

./local.sh 1 2 ./test_example
num of workers[2]
num of servers[1]
start Worker rank = 0
start Worker rank = 1
0 key [0] vals [ 0]
0 key [1] vals [ 0 0]
0 key [2] vals [ 0 0 0]
0 key [3] vals [ 0 0 0 0]
0 key [4] vals [ 20 20 20 20 20]
0 key [5] vals [ 20 20 20 20 20 40]
1 key [0] vals [ 0]
0 key [6] vals [ 40 40 40 40 40 40 40]
1 key [1] vals [ 0 0]
1 key [2] vals [ 0 0 0]
0 key [7] vals [ 40 40 60 60 60 60 60 60]
1 key [3] vals [ 0 0 0 0]
0 key [8] vals [ 60 60 60 60 80 80 80 80 80]
1 key [4] vals [ 20 20 20 20 20]
0 key [9] vals [ 80 80 80 80 80 100 100 100 100 100]
1 key [5] vals [ 20 20 20 20 20 40]

1 key [6] vals [ 40 40 40 40 40 40 40]
1 key [7] vals [ 40 40 60 60 60 60 60 60]
1 key [8] vals [ 60 60 60 60 80 80 80 80 80]
1 key [9] vals [ 80 80 80 80 80 100 100 100 100 100]

参考:

Parameter Server for Distributed Machine Learning
PS-Lite Documents
ps-lite源码剖析

你可能感兴趣的:(MXNet)