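This tutorial uses PySyft, a Python 3 federated learning framework, to implement federated learning across multiple machines in a setting close to a real deployment. The code uses server workers and client workers (WebsocketServerWorker and WebsocketClientWorker) rather than VirtualWorker.
Federated machine learning, also called federated learning, joint learning, or alliance learning, is a machine learning framework that helps multiple organizations use data and build models while meeting the requirements of user privacy protection, data security, and government regulation [Baidu Baike]. Put simply, the parties jointly train a third party's model without disclosing their distributed data, as shown in Figure 1.
This article uses PySyft alone to build a basic federated learning code skeleton in Python 3. It does not cover encryption or more complex operations such as gradient averaging by a third party. Both Linux and Windows 10 work, and the key package versions are shown in Figure 2. Before reading on, please first learn the basic PySyft operations: PySyft tutorials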
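Because Figure 2 is an image, the package versions cannot be copied from the text. A small sketch like the one below could be used to record the versions in your own environment before running the examples. The helper `installed_version` is hypothetical, and the package names in the loop (`torch`, `syft`, `websockets`) are assumptions about which dependencies matter here.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumed package names; adjust the list to match your environment.
    for name in ("torch", "syft", "websockets"):
        print(name, installed_version(name) or "not installed")
```

Running this on both the server and client machines makes it easier to spot a version mismatch between them.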
'''
--------------------------------------------------------
@File : server.py
@Contact : [email protected] or [email protected]
@License : (C)Copyright 2017-2018, CS, WHU
@Modify Time : 2020/6/16 20:42
@Author : Liu Wang
@Version : 1.0
@Description : None
--------------------------------------------------------
'''
import torch
import syft as sy
from syft.workers.websocket_server import WebsocketServerWorker
import sys

# Read the host, worker id and port from the command line
try:
    host = sys.argv[1]
    id = sys.argv[2]
    port = int(sys.argv[3])  # convert the port to a number
    print(host, id, port)
except (IndexError, ValueError) as e:
    print(str(e))
    print('run the server by: "python server.py host id port"')
    print('for example: "python server.py localhost server1 8182"')
    sys.exit(-1)
hook = sy.TorchHook(torch)
server_worker = WebsocketServerWorker(
    host=host,  # e.g. host="192.168.2.101", the host of the server machine
    hook=hook,
    id=id,
    port=port)
# hook = sy.TorchHook(torch, local_worker=server_worker)
# data in server
x = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True).tag("toy", "data")
y = torch.tensor([[0],[0],[1],[1.]], requires_grad=True).tag("toy", "target")
# x.private, y.private = True, True
x_ptr = x.send(server_worker)
y_ptr = y.send(server_worker)
print(x_ptr, y_ptr)
# x = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=False)
# y = torch.tensor([[0],[0],[1],[1.]], requires_grad=False)
# server_worker.add_dataset(sy.BaseDataset(data=x, targets=y), key="vectors")
print('>>> server_worker:', server_worker)
print('>>> server_worker.list_objects():', server_worker.list_objects())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())
# start() blocks while the server is running; the lines below are reached only
# once it stops (you might need to interrupt with `CTRL-C` or some other means)
server_worker.start()
print('>>> server_worker.list_objects()', server_worker.list_objects())
print('>>> server_worker.objects_count()', server_worker.objects_count())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())
print('>>> server_worker.host', server_worker.host)
print('>>> server_worker.port', server_worker.port)
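The positional `sys.argv` handling in server.py could also be written with the standard `argparse` module, which produces the usage message and converts the port to an integer. The following is only a sketch of that alternative, not the author's code. The function name `parse_server_args` is hypothetical, and the argument order (`host id port`) is taken from the usage string printed by server.py.

```python
import argparse


def parse_server_args(argv=None):
    """Parse `host id port` in the same order that server.py reads them."""
    parser = argparse.ArgumentParser(description="Start a websocket server worker")
    parser.add_argument("host", help="address of the server machine, e.g. localhost")
    parser.add_argument("id", help="worker id, e.g. server1")
    parser.add_argument("port", type=int, help="port to listen on, e.g. 8182")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # An explicit list is passed here for illustration; pass nothing to read sys.argv.
    args = parse_server_args(["localhost", "server1", "8182"])
    print(args.host, args.id, args.port)
```

With this in place, a missing or non-numeric port makes `argparse` print the usage text and exit, so no manual `try`/`except` is needed.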
'''
--------------------------------------------------------
@File : client.py
@Contact : [email protected] or [email protected]
@License : (C)Copyright 2017-2018, CS, WHU
@Modify Time : 2020/6/16 20:42
@Author : Liu Wang
@Version : 1.0
@Description : None
--------------------------------------------------------
'''
import torch
from torch import optim
import syft
# from syft.grid.public_grid import PublicGridNetwork
from syft.workers.websocket_client import WebsocketClientWorker
hook = syft.TorchHook(torch)
def train(model, datasets, ITER=20) -> torch.nn.Module:
    """
    :param model: the torch model
    :param datasets: pointers to the datasets held by the server workers,
        in the format [(data_ptr, target_ptr), (data_ptr, target_ptr), ...]
    :param ITER: the number of iterations
    :return: the trained copy of the model
    """
    model_c = model.copy()
    # Training logic
    for _ in range(ITER):
        for data, target in datasets:
            # 1) send the model to the worker that holds this data
            model_c = model_c.send(data.location)
            # 2) create an SGD optimizer over the remote model's parameters
            opt = optim.SGD(params=model_c.parameters(), lr=0.1)
            # 3) erase previous gradients (if they exist)
            opt.zero_grad()
            # 4) make a prediction
            pred = model_c(data)
            # 5) calculate how much we missed
            loss = ((pred - target) ** 2).sum()
            # 6) figure out which weights caused us to miss
            loss.backward()
            # 7) change those weights
            opt.step()
            # 8) get the updated model back from the worker
            model_c = model_c.get()
            # 9) print our progress
            print(data.location.id, loss.get())
    return model_c
if __name__ == '__main__':
    # create client workers that map to the server workers on the remote machines
    remote_client_1 = WebsocketClientWorker(
        host='localhost',
        # host = '192.168.0.102', # the host of remote machine, the same as the Server host
        hook=hook,
        id='server1',
        port=8182)
    remote_client_2 = WebsocketClientWorker(
        host='localhost',
        # host = '192.168.0.102', # the host of remote machine, the same as the Server host
        hook=hook,
        id='server2',
        port=8183)
    remote_clients_list = [remote_client_1, remote_client_2]
    print('>>> remote_client_1', remote_client_1)
    print('>>> remote_client_2', remote_client_2)
    # get the data pointers which point to the real data on the remote machines for training the model
    datasets = []
    for remote_client in remote_clients_list:
        data = remote_client.search(["toy", "data"])[0]
        target = remote_client.search(["toy", "target"])[0]
        print('>>>data: ', data)
        print('>>>target: ', target)
        datasets.append((data, target))
    # exit(0)
    # define torch model
    model = torch.nn.Linear(2, 1)
    print('>>> untrained model: ', model.state_dict())
    # train model
    trained_model = train(model, datasets, ITER=10)
    print('>>> trained model: ', trained_model.state_dict())
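The `train` function above visits each data holder in turn: the model travels to the data, takes one SGD step there, and comes back, with no averaging step. To make that flow easier to follow without running any servers, here is a minimal plain-Python sketch of the same round-robin idea on the toy data. It does not use PySyft. The names `local_step` and `train_round_robin` are hypothetical, and the weights start at zero rather than at the random initialization of `torch.nn.Linear`.

```python
def local_step(weights, bias, features, targets, lr=0.1):
    """One gradient-descent step on the summed squared error of one site's data."""
    grad_w = [0.0] * len(weights)
    grad_b = 0.0
    loss = 0.0
    for x, t in zip(features, targets):
        pred = sum(w * xi for w, xi in zip(weights, x)) + bias
        err = pred - t
        loss += err ** 2
        for j, xi in enumerate(x):
            grad_w[j] += 2 * err * xi
        grad_b += 2 * err
    new_weights = [w - lr * g for w, g in zip(weights, grad_w)]
    # The returned loss is the one measured before the update, as in train().
    return new_weights, bias - lr * grad_b, loss


def train_round_robin(sites, rounds=10):
    """Pass one linear model from site to site; raw data never leaves its site."""
    weights, bias = [0.0, 0.0], 0.0
    history = []
    for _ in range(rounds):
        for features, targets in sites:
            weights, bias, loss = local_step(weights, bias, features, targets)
            history.append(loss)
    return weights, bias, history


if __name__ == "__main__":
    toy_x = [[0, 0], [0, 1], [1, 0], [1, 1]]
    toy_y = [0, 0, 1, 1]
    # Two sites holding the same toy data, like server1 and server2.
    sites = [(toy_x, toy_y), (toy_x, toy_y)]
    weights, bias, history = train_round_robin(sites, rounds=10)
    print("first loss:", history[0], "last loss:", history[-1])
```

In the PySyft version the inner call corresponds to steps 1 to 8 of `train`, executed on the remote worker through pointers instead of on local lists.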
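This tutorial uses PySyft, a Python 3 federated learning framework, to implement basic federated learning across multiple machines. The example has a serious flaw, however: once the client holds a data pointer, it can call the pointer's .get() function to fetch the data from the server. Avoiding this requires the PyGrid package and setting permissions on the data. I will post a follow-up when time allows.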
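The flaw described above comes down to the data holder never checking who is asking for the raw data. The following toy sketch illustrates the kind of permission check that is needed. It is purely conceptual and is not the PyGrid or PySyft API: the class `GuardedStore` and its methods are hypothetical names invented for this illustration.

```python
class GuardedStore:
    """Toy stand-in for a data holder that checks who may fetch raw data."""

    def __init__(self, allowed_users=()):
        self._objects = {}
        self._allowed = set(allowed_users)

    def put(self, key, value):
        """Store an object on the holder's side."""
        self._objects[key] = value

    def get(self, key, user):
        """Return the raw object only to users on the allow list."""
        if user not in self._allowed:
            raise PermissionError(f"{user!r} may not fetch {key!r}")
        return self._objects[key]


if __name__ == "__main__":
    store = GuardedStore(allowed_users={"owner"})
    store.put("toy-data", [[0, 0], [0, 1], [1, 0], [1, 1]])
    print(store.get("toy-data", user="owner"))
    try:
        store.get("toy-data", user="client")
    except PermissionError as err:
        print("refused:", err)
```

A real deployment would also need the training operations themselves to avoid leaking the data, which this sketch does not address.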