Parameter Server中文名称叫做参数服务器,是分布式机器学习框架中用来做参数同步的框架。具体介绍可以参考后面链接,这里主要学习一下其实现。
ps-lite是Paramter Server的实现的一个框架,其中参数处理具体相关策略需自己实现。
Parameter Server包含三种角色:Worker、Server、Scheduler。具体关系如下图:
push
到Server,同时从Serverpull
参数回来。简单看一下各个类以及它们之间的关系
std::unordered_map senders_
保存了node_id到连接的映射。Van只是定义了接口,具体实现是依赖ZMQ实现的ZMQVan
。三种节点,从上图可以看出Scheduler节点只有一个,多个Worker和多个Server可以组成一个Group,因此有WorkerGroup和ServerGroup;还有Worker节点和Server节点。每个节点以及每一个Group都有唯一确定的ID。
Scheduler、ServerGroup、WorkerGroup节点ID确定如下:
static const int kScheduler = 1;
static const int kServerGroup = 2;
static const int kWorkerGroup = 4;
1、2、4的二进制表示分别为:001、010、001。这样可以做Group之间的合并,例如要和ServerGroup和WorkerGroup发信息,只需要destination node id设为2+4=6。
1-7用来表示节点的组合。单个节点的ID从8开始。单个Server和单个Worker节点从自己的rank(0、1、2……)转换到其ID:
static inline int WorkerRankToID(int rank) {
return rank * 2 + 9;
}
static inline int ServerRankToID(int rank) {
return rank * 2 + 8;
}
ID到其rank转换:
static inline int IDtoRank(int id) {
return std::max((id - 8) / 2, 0);
}
Postoffice
中td::unordered_map
保存了Node/NodeGroup与连接节点集合的对应关系。
SArray
,Smart Array。共享数据,减少数据拷贝,且提供了类似vector的接口。Meta
使用了Protobuf,进行了数据压缩。Node
包含节点的角色、id、ip、端口信息;Control
包含了命令信息、签名等;Meta
是元数据,包含时间戳、发送者、接受者、控制信息等;Message
才是发送的信息,包含元数据和发送的数据。KVPairs
。Scheduler节点管理所有节点的地址。每个节点要知道Scheduler节点的IP、port;启动时绑定一个本地端口,并向Scheduler节点报告。Scheduler节点在每个几点启动后,给节点分配ID,把节点信息通知出去(例如Worker节点要知道Server节点IP和端口,Server节点要知道Worker节点的IP和端口)。节点在建立连接后,才会正式启动。
异步工作时,Worker计算参数可能要依赖前面Pull
是否完成。如果需要等待某一步操作,可以调用SimpleApp::Wait
操作。具体实现是调用了Customer::WaitRequest()
,它会跟踪request和response数量是否相同,直到相同才会返回;tracker_
类型为std::vector
,记录了request和response数量,这个数据结构一直增长,会造成内存一直增长。
每个节点都监听了本地一个端口;该连接的节点在启动时已经连接。
对于Server节点:
1、Van::Receiving()
函数是单独一个线程来接收数据。数据接收后,根据不同命令执行不同动作,例如Control::ADD_NODE
就是添加节点。如果需要下一步处理,会将消息传递给Customer::Accept
函数。
2、Customer::Accept()
函数将消息添加到一个队列recv_queue_
;Customer::Receiving()
是一个线程在运行,从队列取消息处理;处理过程中会使用函数对象recv_handle_
处理消息,这个函数对象是SimpleApp::Process
函数。
3、SimpleApp::Process
根据是消息类型(请求or响应,调用用户注册的函数来处理消息,request_handle_
、response_handle_
分别处理请求和响应。
对于Worker节点,上面第3点略有不同。因为Worker都是通过Push
、Pull
来通信,而且参数都是key-value对。Pull·参数时,通过KVWorker::Process
调用回调函数来处理消息。
PS Lite通过环境变量和外界交互。
启动流程:
1、首先启动Scheduler节点。这是要固定好Server和Worker数量。
2、启动Worker或Server节点。启动时连接Scheduler节点,绑定本地端口,并向Scheduler节点注册自己信息。
3、Scheduler等待所有Worker节点都注册后,给其分配id,并把节点信息传送出去。此时Scheduler节点已经准备好。
4、Worker或Server接收到Scheduler传送的信息后,建立对应节点的连接。此时Worker或Server已经准备好。
调试时,通过环境变量来控制调试日志。
PS_VERBOSE=1,会打印连接日志。
PS_VERBOSE=2,会打印所有数据通信日志。
参考源码给出的例子。KVPair
中一个key对应多个value,具体数量在lens中记录。Server使用key-vector映射存储数据。Server收到的Push
数据,只是将对应key的值相加。最后Worker Pull
的数据,按照key打印。
#include
#include "ps/ps.h"
using namespace std;
using namespace ps;
template <class Val>
class MyKVServerHandle {
public:
void operator() (const KVMeta& req_meta, const KVPairs& req_data, KVServer* server) {
size_t n = req_data.keys.size();
KVPairs res;
if (req_meta.push) { // push
CHECK_EQ(n, req_data.lens.size());
} else { // pull
res.keys = req_data.keys;
res.lens.resize(res.keys.size());
}
size_t cur_idx = 0;
for (size_t i = 0;i < n; ++i) {
Key key = req_data.keys[i];
if(req_meta.push){ //push
int len = req_data.lens[i];
if(store.count(key) == 0){//第一次push,开辟空间
store[key] = vector (len, 0);
}
for(int idx = 0; idx < len; ++idx){
store[key][idx] += req_data.vals[cur_idx++];
}
}
else{ // pull
res.lens[i] = store[key].size();
for(int idx = 0; idx < res.lens[i]; ++idx){
res.vals.push_back(store[key][idx]);
}
}
}
server->Response(req_meta, res);
}
private:
std::unordered_mapvector > store;
};
void StartServer() {
if (!IsServer()) return;
cout << "num of workers[" << NumWorkers() << "]" << endl;
cout << "num of servers[" << NumServers() << "]" << endl;
auto server = new KVServer<float>(0);
server->set_request_handle(MyKVServerHandle<float>());
RegisterExitCallback([server](){ delete server; });
}
void RunWorker() {
if (!IsWorker()) return;
cout << "start Worker rank = " << MyRank() << endl;
KVWorker<float> kv(0);
// init
int key_num = 10;
int val_num = 0;
vector keys(key_num);
vector<int> len(key_num);
for(int i = 0; i < key_num; ++i){
keys[i] = i;
len[i] = i + 1;
val_num += len[i];
}
vector<float> vals(val_num);
for (int i = 0;i < val_num; ++i) {
vals[i] = i / 10;
}
// push
int repeat = 10;
vector<int> ts;
for (int i = 0;i < repeat; ++i) {
ts.push_back(kv.Push(keys, vals, len));
}
for (int t : ts) kv.Wait(t);
// pull
std::vector<float> ret_val;
std::vector<int> ret_len;
kv.Wait(kv.Pull(keys, &ret_val, &ret_len));
CHECK_EQ(keys.size(), ret_len.size());
size_t cur_idx = 0;
for (size_t i = 0;i < keys.size(); ++i) {
std::cout<" key ["<"] vals [";
for(int idx = 0; idx < ret_len[i]; ++idx){
std::cout<<" "<std::cout<<"]"<<std::endl;
}
cout << endl;
}
int main(int argc, char* argv[]) {
StartServer();
Start();
RunWorker();
Finalize();
return 0;
}
一个Server,两个Worker
./local.sh 1 2 ./test_example
num of workers[2]
num of servers[1]
start Worker rank = 0
start Worker rank = 1
0 key [0] vals [ 0]
0 key [1] vals [ 0 0]
0 key [2] vals [ 0 0 0]
0 key [3] vals [ 0 0 0 0]
0 key [4] vals [ 20 20 20 20 20]
0 key [5] vals [ 20 20 20 20 20 40]
1 key [0] vals [ 0]
0 key [6] vals [ 40 40 40 40 40 40 40]
1 key [1] vals [ 0 0]
1 key [2] vals [ 0 0 0]
0 key [7] vals [ 40 40 60 60 60 60 60 60]
1 key [3] vals [ 0 0 0 0]
0 key [8] vals [ 60 60 60 60 80 80 80 80 80]
1 key [4] vals [ 20 20 20 20 20]
0 key [9] vals [ 80 80 80 80 80 100 100 100 100 100]
1 key [5] vals [ 20 20 20 20 20 40]
1 key [6] vals [ 40 40 40 40 40 40 40]
1 key [7] vals [ 40 40 60 60 60 60 60 60]
1 key [8] vals [ 60 60 60 60 80 80 80 80 80]
1 key [9] vals [ 80 80 80 80 80 100 100 100 100 100]
Parameter Server for Distributed Machine Learning
PS-Lite Documents
ps-lite源码剖析