A Python 3 Code Example of Real-World Federated Learning with PySyft


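This tutorial uses PySyft, a Python 3 federated learning framework, to implement basic multi-machine federated learning in a setting close to a real deployment. The code uses WebsocketServerWorker and WebsocketClientWorker rather than VirtualWorker.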

1. Technical background

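Federated machine learning, also called federated learning, joint learning, or alliance learning, is a machine learning framework that helps multiple organizations use data and build models while meeting requirements for user privacy, data security, and government regulation [Baidu Baike]. Put simply, a third party's model is trained jointly on distributed data that is never published, as shown in Figure 1.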

Figure 1. Schematic of federated learning
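To make the idea concrete, here is a minimal pure-Python sketch that is independent of PySyft. `DataHolder`, `local_step`, and `train_round_robin` are hypothetical names introduced for illustration. Only the model parameters travel between the analyst and each holder; the samples stay inside the holder object. The toy data matches the tensors used in server.py later in this article.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


class DataHolder:
    """Hypothetical data owner: samples stay inside this object."""

    def __init__(self, name: str, features: List[Vector], labels: List[float]) -> None:
        self.name = name
        self._features = features
        self._labels = labels

    def local_step(self, weights: Vector, bias: float, lr: float) -> Tuple[Vector, float, float]:
        """Run one gradient step on the private data.

        Returns the updated weights, the updated bias, and the
        summed squared error measured before the update.
        """
        grad_w = [0.0] * len(weights)
        grad_b = 0.0
        loss = 0.0
        for x, y in zip(self._features, self._labels):
            err = sum(w * xi for w, xi in zip(weights, x)) + bias - y
            loss += err ** 2
            for i, xi in enumerate(x):
                grad_w[i] += 2.0 * err * xi
            grad_b += 2.0 * err
        new_weights = [w - lr * g for w, g in zip(weights, grad_w)]
        return new_weights, bias - lr * grad_b, loss


def train_round_robin(holders: Sequence[DataHolder], rounds: int = 10,
                      lr: float = 0.1) -> Tuple[Vector, float, List[float]]:
    """Pass a linear model from holder to holder; only parameters move."""
    weights, bias = [0.0, 0.0], 0.0
    losses = []
    for _ in range(rounds):
        for holder in holders:
            weights, bias, loss = holder.local_step(weights, bias, lr)
            losses.append(loss)
    return weights, bias, losses


if __name__ == "__main__":
    x = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
    y = [0.0, 0.0, 1.0, 1.0]
    holders = [DataHolder("A", x, y), DataHolder("B", x, y)]
    weights, bias, losses = train_round_robin(holders, rounds=10)
    print(weights, bias, losses[-1])
```

The analyst in this sketch never reads `_features` or `_labels`; it only sees parameters and a scalar loss, which is the property the PySyft workers below are meant to provide across machines.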

2. Runtime environment

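This article uses plain PySyft to build a basic Python 3 federated learning skeleton. It does not cover more complex operations such as encryption or gradient averaging by a third party. Both Linux and Windows 10 work; the key package versions are shown in Figure 2. Before reading on, please work through the basic PySyft operations first: PySyft tutorials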

Figure 2. Key package versions
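The figure itself is not reproduced here, so the exact versions are not restated. To compare your own environment against it, a small check along these lines may help. `installed_versions` is a hypothetical helper; it assumes Python 3.8 or later and the PyPI distribution names `torch` and `syft`.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    report: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, version in installed_versions(["torch", "syft"]).items():
        print(f"{name}: {version or 'not installed'}")
```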

3. Python 3 code example

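  • In PySyft, the data owners (enterprises A and B) act as servers. Each one deploys and starts a ServerWorker and uploads its tagged data to that worker. The code is as follows: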
'''
--------------------------------------------------------
@File    :   server.py    
@Contact :   [email protected] or [email protected]
@License :   (C)Copyright 2017-2018, CS, WHU
@Modify Time : 2020/6/16 20:42     
@Author      : Liu Wang    
@Version     : 1.0   
@Description : None
--------------------------------------------------------  
'''
import torch
import syft as sy
from syft.workers.websocket_server import WebsocketServerWorker
import sys

try:
    host = sys.argv[1]
    worker_id = sys.argv[2]  # renamed from `id` to avoid shadowing the builtin
    port = int(sys.argv[3])  # parse the port as an integer
    print(host, worker_id, port)
except (IndexError, ValueError) as e:
    # a missing argument or a non-numeric port ends up here
    print(str(e))
    print('run the server by: "python server.py host id port"')
    print('for example: "python server.py localhost server1 8182"')
    sys.exit(-1)

hook = sy.TorchHook(torch)
server_worker = WebsocketServerWorker(host=host,  # e.g. host="192.168.2.101", the address of the server machine
                                      hook=hook, id=worker_id, port=port)
# hook = sy.TorchHook(torch, local_worker=server_worker)

# data in server
x = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True).tag("toy", "data")
y = torch.tensor([[0],[0],[1],[1.]], requires_grad=True).tag("toy", "target")
# x.private, x.private = True, True

x_ptr = x.send(server_worker)
y_ptr = y.send(server_worker)
print(x_ptr, y_ptr)

# x = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=False)
# y = torch.tensor([[0],[0],[1],[1.]], requires_grad=False)
# server_worker.add_dataset(sy.BaseDataset(data=x, targets=y), key="vectors")

print('>>> server_worker:', server_worker)
print('>>> server_worker.list_objects():', server_worker.list_objects())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())

server_worker.start()  # blocks while serving; interrupt with `CTRL-C` (the prints below run only after it returns)

print('>>> server_worker.list_objects()', server_worker.list_objects())
print('>>> server_worker.objects_count()', server_worker.objects_count())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())
print('>>> server_worker.host', server_worker.host)
print('>>> server_worker.port', server_worker.port)
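client.py below connects to two workers, `server1` on port 8182 and `server2` on port 8183, so two instances of server.py must be running before the client starts. For local testing on one machine, a small launcher along these lines could start both. `build_commands` and `launch` are hypothetical helpers, and the sketch assumes server.py sits in the current directory.

```python
import subprocess
import sys
from typing import List, Sequence, Tuple


def build_commands(host: str, workers: Sequence[Tuple[str, int]],
                   script: str = "server.py") -> List[List[str]]:
    """Build one 'python server.py host id port' command line per worker."""
    return [[sys.executable, script, host, worker_id, str(port)]
            for worker_id, port in workers]


def launch(commands: List[List[str]]) -> List[subprocess.Popen]:
    """Start every server as a child process and return the handles."""
    return [subprocess.Popen(cmd) for cmd in commands]


if __name__ == "__main__":
    # ids and ports match the ones hard-coded in client.py
    processes = launch(build_commands("localhost",
                                      [("server1", 8182), ("server2", 8183)]))
    try:
        for proc in processes:
            proc.wait()
    except KeyboardInterrupt:
        # stop all servers when the launcher is interrupted
        for proc in processes:
            proc.terminate()
```

On separate machines, the same command line is run once on each host with that host's own address in place of `localhost`.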
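  • Start the servers first, then run the client. In PySyft, the model owner (the data analyst) acts as the client. It deploys one ClientWorker per ServerWorker, searches each server's data by tag, and receives pointers to the training data. The training procedure is similar to the example in the tutorial. The code is as follows: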
'''
--------------------------------------------------------
@File    :   client.py    
@Contact :   [email protected] or [email protected]
@License :   (C)Copyright 2017-2018, CS, WHU
@Modify Time : 2020/6/16 20:42     
@Author      : Liu Wang    
@Version     : 1.0   
@Description : None
--------------------------------------------------------  
'''
import torch
from torch import optim
import syft
# from syft.grid.public_grid import PublicGridNetwork
from syft.workers.websocket_client import WebsocketClientWorker
hook = syft.TorchHook(torch)

def train(model, datasets, ITER=20) -> torch.nn.Module:
    """
    :param model: the torch model
    :param datasets: pointers to the datasets held by the server workers,
            in the format [(data_ptr, target_ptr), (data_ptr, target_ptr), ...]
    :param ITER: the number of passes over all datasets
    :return: the trained copy of the model
    """
    model_c = model.copy()
    # Training Logic
    for _ in range(ITER):
        for data, target in datasets:
            # 1) send the model to the worker that holds this dataset
            model_c = model_c.send(data.location)
            # 2) create a fresh SGD optimizer over the remote model's parameters
            opt = optim.SGD(params=model_c.parameters(), lr=0.1)
            # 3) erase previous gradients (if they exist)
            opt.zero_grad()
            # 4) make a prediction
            pred = model_c(data)
            # 5) calculate how much we missed
            loss = ((pred - target)**2).sum()
            # 6) figure out which weights caused us to miss
            loss.backward()
            # 7) change those weights
            opt.step()
            # 8) get model (with gradients)
            model_c = model_c.get()
            # 9) print our progress
            print(data.location.id, loss.get())
    return model_c

if __name__ == '__main__':
    # create a client workers mapping to the server workers in remote machines
    remote_client_1 = WebsocketClientWorker(
        host='localhost',
        # host = '192.168.0.102', # the host of remote machine, the same as the Server host
        hook=hook,
        id='server1',
        port=8182)
    remote_client_2 = WebsocketClientWorker(
        host='localhost',
        # host = '192.168.0.102', # the host of remote machine, the same as the Server host
        hook=hook,
        id='server2',
        port=8183)
    remote_clients_list = [remote_client_1, remote_client_2]
    print('>>> remote_client_1', remote_client_1)
    print('>>> remote_client_2', remote_client_2)

    # get the data pointers which point to the real data in remote machines for training model
    datasets = []
    for remote_client in remote_clients_list:
        data = remote_client.search(["toy", "data"])[0]
        target = remote_client.search(["toy", "target"])[0]
        print('>>>data: ', data)
        print('>>>target: ', target)
        datasets.append((data, target))
    # exit(0)
    # define torch model
    model = torch.nn.Linear(2, 1)
    print('>>> untrained model: ', model.state_dict())
    # train model
    trained_model = train(model, datasets, ITER=10)
    print('>>> trained model: ', trained_model.state_dict())

4. Summary

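This tutorial uses PySyft, a Python 3 federated learning framework, to implement basic multi-machine federated learning. The example has a serious flaw, however: once the client holds a data pointer, it can call the pointer's .get() function to pull the data from the server. Preventing this requires the PyGrid package together with permissions set on the data. A follow-up will be posted when time allows.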
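The flaw and the intended fix can be illustrated with a toy model in plain Python. This is not PySyft or PyGrid code: `ToyWorker`, `ToyPointer`, and the `allowed_users` argument are hypothetical names, and the tag matching only assumes that a search returns the objects carrying every requested tag.

```python
from typing import Any, Dict, List, Sequence


class ToyWorker:
    """Hypothetical data holder with a tagged object store."""

    def __init__(self) -> None:
        self._store: Dict[int, Dict[str, Any]] = {}

    def register(self, value: Any, tags: Sequence[str],
                 allowed_users: Sequence[str] = ()) -> int:
        """Store a value with its tags and the users allowed to fetch it."""
        obj_id = len(self._store)
        self._store[obj_id] = {
            "value": value,
            "tags": set(tags),
            "allowed": set(allowed_users),
        }
        return obj_id

    def search(self, query: Sequence[str]) -> List["ToyPointer"]:
        """Return pointers to every object that carries all queried tags."""
        wanted = set(query)
        return [ToyPointer(self, obj_id)
                for obj_id, entry in self._store.items()
                if wanted <= entry["tags"]]

    def fetch(self, obj_id: int, user: str) -> Any:
        """Hand out the raw value only to users on the allow-list."""
        entry = self._store[obj_id]
        if user not in entry["allowed"]:
            raise PermissionError(f"{user} may not fetch object {obj_id}")
        return entry["value"]


class ToyPointer:
    """Reference to a remote object; holds no data itself."""

    def __init__(self, worker: ToyWorker, obj_id: int) -> None:
        self.worker = worker
        self.obj_id = obj_id

    def get(self, user: str) -> Any:
        # Without the allow-list check in fetch(), any holder of the
        # pointer could retrieve the data, which is the flaw described above.
        return self.worker.fetch(self.obj_id, user)
```

In this sketch the analyst can still obtain pointers through `search`, so remote training remains possible, while `get` fails for anyone the data owner has not listed.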
