pysyft是基于pytorch的一个联邦学习框架(虽然用起来很难受),通过内存管理实现联邦学习的模拟。
在pysyft中,WebsocketServerWorker充当数据的提供方(数据存储方),而WebsocketClientWorker作为数据的使用方(指令提供方),通过WebsocketClientWorker以TCP连接的方式向WebsocketServerWorker请求服务,从而实现分布式训练。
笔者的pysyft版本为0.2.0(较为经典的版本),在编写代码时,想要使用WebsocketServerWorker向WebsocketClientWorker发送数据,但是网上没有相关文档描述,且GPT胡言乱语,所以在阅读pysyft的源码后,记录一下心得。
在实例化client类后,发现remote相关的函数非常多,但是不知道从哪里下手处理。对于WebsocketServerWorker与WebsocketClientWorker的通信模型,有一个通用的代码框架。
WebsocketServerWorker:
# 建立server
worker = WebsocketServerWorker(**kwargs)
# 建立数据集
dataset = sy.BaseDataset(data, target)
# 将数据集添加到server中
worker.add_dataset(dataset, key="xor")
WebsocketClientWorker:
# 初始化模型
traced_model = th.jit.trace(model, mock_data)
# 配置训练采纳数
train_config = sy.TrainConfig(model=traced_model,
loss_fn=loss_fn,
optimizer=optimizer,
batch_size=batch_size,
optimizer_args=optimizer_args,
epochs=epochs,
shuffle=shuffle)
# 将配置通信参数
kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
train_config.send(alice)
# 开始远程训练
for epoch in range(10):
loss = alice.fit(dataset_key="xor")
这里我发现在alice.fit时,需要将dataset_key发送到server,打算从fit函数这里入手,来阅读源码。
有fit函数可以知道,这里最关键的函数是_send_msg_and_deserialize,所以继续阅读_send_msg_and_deserialize函数。
这里的代码还是有点复杂,于是笔者尝试将message类进行打印。发现message并没有包含通讯相关的内容。但是发现了_send_msg函数。(这里的serialize与deserialize函数使用来通信的序列化与反序列化)读到这里后,基本上可以肯定,这个函数使用来做端口通信的。因为只需要实现消息传输,所以不需要过度深入源码。
在_send_msg_and_deserialize的下方,我又发现了几个简单的函数,打算从这个函数入手进行分析。
果然,我猜测这里用到了python的函数反射,将字符串映射到了函数的名称。在WebsocketServerWorker中发现了list_objects函数。
到这里,基本上就可以猜想到如何实现通信了。
1、在WebsocketServerWorker或其子类中实现get_id()方法:
class CustomWebsocketServerWorker(WebsocketServerWorker):
def get_colonyId(self, *args):
return self.colony
2、在WebsocketClientWorker中请求RPC调用
class CustomWebsocketClientWorker(WebsocketClientWorker):
colony_id = -1
def get_colonyId(self, **kwargs):
return self._send_msg_and_deserialize("get_colonyId")
然后就可以实现任意参数的消息传输。