转载请注明出处:http://blog.csdn.net/zkwdn/article/details/53840091
参数服务器的文章介绍太多了。我也只大概了解了下petuum和ps-lite
我本身对spark很熟悉,所以就从spark的角度去理解为什么会有ps这个东西呢。
spark的架构是一个driver端,多个client端。driver和每个client端都会存储所有的参数。以lr为例,
1. 在driver端对所有的参数进行初始化,假设有10000个参数,driver端会存储一个10000维的array。
2.将这10000维的array传到每一个client端
3.每个client端利用这个array和自身上的数据计算对每个参数梯度的更新。也是一个10000的array存储每个参数的梯度更新值。
4.每个client端将自己的这个10000维的array传给driver端
5.driver端调用reduce接口,将每个client的这10000维向量相加,得到每个参数的梯度更新
6.driver端根据每个梯度的更新值更新参数值。
7.driver端将更新后的参数值(10000维数组)发送到每一个client端,重复第3步。
我们可以看到2个问题。
1.Spark采用的完全是BSP协议,即第二轮迭代必须等到第一轮迭代所有的机器完成。但是每个机器的性能是不一致的,有的机器上分配的任务比较多,所以执行速度慢,有的机器分配的任务少执行快。这样执行快的机器必须等待执行慢的机器执行完,这段时间内快的机器就啥事不干,造成资源浪费。所以基于此产生了ssp协议,ssp协议即最快的机器执行的迭代次数和最慢的机器执行的迭代次数的差距不能大于某个数值t。具体的可以看下erxing老先生的几个图
这2张图应该是petuum的精髓了。
2.spark参与计算的每个机器上都存着全局所有的参数,不管是driver端还是client端。这就导致一个问题,当参数量上亿的时候,driver端和client端是没有足够的内存存下那么多参数的,所以ps的另外一个优点就是解决这个问题。首先每个client端只需要存储跟自己相关的参数(这就大大节约了非常多的内存空间),其次spark的一个driver端被分解成多个服务器端,每个服务器也只是存储跟自己相关联的client的参数。这样就解决了参数过多,内存不够的问题。上2个图,第一个图是petuum,第二个图是李牧的
上面大概阐述了ps的两个优点。接下来我已ps-lite为例,描述ps-lite中怎么写lr或者其他模型。ps-lite自带的源代码https://github.com/dmlc/wormhole/tree/master/learn/linear,不过这个太复杂,可以先从舒克飞行员的https://github.com/kunguang/Field-aware-Factorization-Machine-ps或者https://github.com/ljzzju/logistic-regression-ftrl-ps研究。
1.先看logistic-regression-ftrl-ps这个
一共6个文件,dump.cc作用:读取二进制模型文件并解析,然后存储到文件里。load_data.h:读取训练集,main.cpp:整个程序的入口,读取配置文件,启动其它程序。scheduler.h主要用来指定每个worker应该处理哪些数据的,但是这个代码并没有实现,在ps-lite中实现了。server.h相当于spark的driver端,收集上来每个worker的梯度更新值,然后更新参数。worker.h:类似于spark的client端,得到本机器上的数据对每个参数的梯度贡献,然后传给server端。
上面的文件唯一需要总结的就是server.h的架构。这个很关键,我们先想想机器学习优化时候分为以下3个部分:优化方法(sgd,adagrad,adam,ftrl等等),正则化(l1,l2),梯度值(损失函数决定的)。其中梯度值是在worker端计算好的,优化方法和正则化这2部分就要在server端实现了。
我们先来看
这个server类继承了ps:App这个顶级抽象类,实现了CreateServer和ProcessRequest两个方法。重点关注CreateServer这个函数,这里面创建了OnlineServer这个变量,传进去3个模板参数,float表示w的类型是float类型,Entry和Handle这2个是重点。这个代码里实现的是ftrl的优化算法,
这个FTRLENTRY表示的其中一个w参数的结构体,w表示参数的值,z表示上面代码中的z,sq_cum_grad表示当前这个w积累的梯度平方和。
FTRLHANDLER:目的就是计算当梯度传过来后,该怎么更新参数。以及当worker端发送过来pull请求,如何返回值。实现了push和pull两个方法,熟悉ps的人,我就不讲作用了。
对于push方法,传进去3个参数,第一个是key,没什么用,第二个是Blob结构体(一个数组结构体),Blob结构体的作用是用来在server和worker之间传输的。push方法是server端接受到worker端push方法后所做的处理,Blob结构体就是worker短传送过来的该参数的梯度,val是在当前server中存储的该w参数的历史数据。然后里面的方法就是用上面的公式更新该参数w的值。
对于pull方法 ,作用就是返回值。val是当前参数w在服务器上的数据,send就是返回给worker,lamda1,lambda2就是正则值。
OK,到目前为止,解析完了当前代码。这个代码非常的简洁,对我这种初学者理解Ps的工作方式非常有帮助。接下来我们来介绍ps-lite自带的lr系统,就比较有模块化的感觉,方便我们以后可以更好在这个基础上写其他模型。架构都和上面一样。只是这个模块的代码抽象了一层有一层,复杂的不能再复杂了。。。
scheduler,server,worker这3个类都要继承ps:App这个顶级类
首先分析scheduler的类型继承链:ps:App->DataParScheduler->IterScheduler->MinibatchScheduler->AsgdScheduler
其次worker的类型继承链:ps:App->DataParWorker->IterWorker->MinibatchWorker->AsgdWorker
其次server的ps:App->IterServer->MinibatchServer->AsgdServer
其他的继承链:DataParCmd->IterCmd。表明这是个什么命令,是存储还是解析等等。
其他类:在base文件夹中,arg_parser.h这是解析配置文件的程序,workload.h和workload_pool.h是读取训练集的文件。
OK,总的架构是这样的。其实上面那个简单的小例子都是直接继承的ps:App这个类,不像李牧自己写的,中间加了好多类。最后模型的解析可以完全参考TLoad这个函数。也可以参考上面简单例子的dump.cc
首先查看scheduler类
这个类的作用,主要是调度,即指定哪个worker处理哪几个partition文件。
再次来看worker类,这个和上面的小例子没有太大区别。但是多了个亮点,就是李沐在论文中说的一些trick,比如某些梯度是0的参数就没必要发送到server端,以及除了第一次以后,就不再传输Key,对于这些trick,专门建立了个抽象类ps::Filter。并且将这些filter类加到worker中
最后重点再来看server类,这个就很重点了,我们来看asgdserver这个类,最大的区别在于
尤其是这个KVStore类,这应该是ps-lite特殊封装的一个类,它是一种server类型,这个地方的逻辑有点绕。大家可以去看源代码。
首先描述KVStore这个抽象类,它有两个实例类,
,这个抽象类有一个变量std::vector
这个变量其实就是最关键的变量,存储(k,v)对,k就是参数w的编号,v就是该参数对应的结构体sgdentry。
上面我们在看到创建这个server_变量时,传入了两个模板变量:Entry和Handler。其中Entry就是每个参数w对应的结构体,Handler就是指采用了哪种优化方法进行优化比如是ftrl还是sgd。
然后这个类有个HandlePush和HandlePull两个方法。而传到这个类里面的模板变量Handler,比如是sgdhandler中又2个方法,push和pull,这个push是指针对其中一个参数w的更新方法,pull是获取其中一个参数w的方法。而HandlerPush和HandlerPull则是处理worker节点传递过来的所有参数节点.大概是这么个样子
遍历data_vectory中的每个参数,然后分别调用push或者Pull处理。就这块有点绕,其他都还好。。
结束语
其实对我来讲,我大概了解下原理,然后知道怎么用起来,怎么去写新的模型才是最重要的。比如我要写个libfm模型,其实我只用修改worker里的gradient计算公式就可以了。如果想修改读取训练集的格式,就用修改workload.h中的文件就行。大部分的机器学习算法应该都只用修改这2点就可以了。其他的一些细节,比如哪个worker就行和哪个server配对。key是怎么划分的,接下来继续研究。。。。