MXNet分布式训练怎么完成的?

MXNet

首先

incubator-mxnet/example/image-classification/中运行

python ../../tools/launch.py -n 2 -s 1 --launcher ssh -H hosts \
--sync-dst-dir /home/xugb/image-classification_test/  \
python train_mnist.py --network lenet --kv-store dist_sync

即python调用incubator-mxnet/tools/launch.py,同时传入参数

agrs function
-n 2
-s 1
--launcher ssh
-H hosts
--sync-dst-dir /home/xugb/image-classification_test/
python train_mnist.py --network lenet --kv-store dist_sync

之后解析这些参数
-n 2 -s 1 --launcher ssh -H hosts --sync-dst-dir /home/xugb/image-classification_test/

python train_mnist.py --network lenet --kv-store dist_sync
并且根据这些配置,调用相应函数。

连接各主机

incubator-mxnet/3rdparty/dmlc-core/tracker/dmlc_tracker/ssh.py

incubator-mxnet/3rdparty/dmlc-core/tracker/dmlc_tracker/tracker.py 初始化PSTracker作为PS的控制节点。然后返回调用ssh.py的函数,ssh连接其他server和worker,并运行python train_mnist.py --network lenet --kv-store dist_sync,环境和DMLC_ROLE都被传过去了。然后当前进程的控制权给PS scheduler。

各个节点的运行

incubator-mxnet/example/image-classification/train_imagenet.py根据传入的args,动态调用
incubator-mxnet/example/image-classification/symbols/lenet.py
incubator-mxnet/example/image-classification/common/date.py/得到data,
得到的参数再传入incubator-mxnet/example/image-classification/common/fit.py的fit函数。
fit则再调用incubator-mxnet/python/mxnet/model/module.py(继承incubator-mxnet/python/mxnet/model/base_module.py),接着看父类base_module的fit的文档

fit.py

  • incubator-mxnet/example/image-classification/common/fit.py
    • kvstore
      • 首先kv = mx.kvstore.create(args.kv_store),来自incubator-mxnet/python/mxnet/kvstore.pycreate()函数
      • 调用c的API,include/mxnet/c_api.hsrc/c_api/c_api.cc
    • kv.rank
      • kv.rank在python/mxnet/kvstore.py的rank()
    • model.fit() 传入kv

module.py

  • python/mxnet/module/module.py继承python/mxnet/kvstore.py,在BaseModulefit
  • BaseModulefit
    • init_params
    • init_optimizer
    • update
    • optimizer
  • update在module.py具体实现了,
    • 调用了python/mxnet/model.py_update_params_on_kvstore
      • kvstore.push(name, grad_list, priority=-index) # push grad
      • kvstore.pull(name, arg_list, priority=-index) # pull weight
      • 于是到了python/mxnet/kvstore.pypushpull函数。
        • python/mxnet/kvstore.pypushpull函数即由src/c_apiinclude/mxnet/c_api.hMXKVStorePushMXKVStorePushExMXKVStorePullMXKVStorePullEx实现
          • c_api其实是为了给其他语言提供统一接口,本质也是在调用src/kvstore下的c语言实现的kvstore。
  • python/mxnet/optimizer.pyupdate决定了如何利用grad更新weight

c++的kvstore

  • src/kvstore/kvstore.cc下,kvstore根据#if MXNET_USE_DIST_KVSTORE来判断创建KVStoreDist或者KVStoreLocal
  • src/kvstore/kvstore_dist.h是继承kvstore_local.h的,并重新实现InitImplPullImplPushImpl
  • kvstore_dist.hkvstore_dist_server.h用到pslite
  • PushImpl
    • Push_
      • PushDefault
        • ps::KVWorker*ps_worker_->ZPush()。在ps-list/include/ps/kv_app.h
          • KVWorker::Send()调用Postoffice::Get()->van()->Send(msg);
            • Van::Send(const Message& msg)调用zmq_van.hSendMsg(msg),用zmq_van.hsenders_来得到socket。而zmq_van.hsenders_void Connect(const Node& node)连接而来
  • PullImpl
    • ``

ps-lite

postoffice

  • postoffice.cc创建van_ = Van::Create("zmq");
    • void Postoffice::Start()启动Van::Start(),调用ZMQVanConnect,建立socket

KVWorker & KVServer

  • 3rdparty/ps-lite/include/ps/kv_app.h加上KVBorker
  • 更改从launch.py开始的host、role解析,调用KVBroker作为管理,更改VAN调用的ZMQVANKAFKAVAN

zmq api

https://www.cnblogs.com/fengbohello/p/4230135.html
naotu.baidu.com


cppkafka


librdkafka

src/rdkafka_transport.crd_kafka_transport_t *rd_kafka_transport_connect用broker thread建立socket连接。

你可能感兴趣的:(MXNet分布式训练怎么完成的?)