[pysyft-002]联邦学习pysyft从入门到精通--三个节点训练一个线性分类器

import syft as sy
import torch
from torch import nn
from torch import optim

"""
https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2002%20-%20Intro%20to%20Federated%20Learning.ipynb
Part 02 - Intro to Federated Learning.ipynb
"""

"""
本例演示:
在A节点上有一个模型。B、C节点上分别有两个样本集。A节点把模型分别送到B和C节点上进行多轮训练。
本脚本运行在A节点上。
"""

#syft需要对pytorch做hook
hook = sy.TorchHook(torch)

#两个worker,每个worker是一个训练节点
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

#数据集,data是样本属性,target是样本类别标记
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)

#数据集拆分成两部分,一部分发给bob训练,一部分发给alice训练。训练出两个模型。
#bob和alice都不知道对方的模型,bot和alice是独立的。

#给bob worker的数据
data_bob = data[0:2]
target_bob = target[0:2]

#给alice worker的数据
data_alice = data[2:]
target_alice = target[2:]

#把数据发给bob和alice,返回的是指针,指向bot和alice上的数据 
p_data_bob = data_bob.send(bob)
p_data_alice = data_alice.send(alice)
p_target_bob = target_bob.send(bob)
p_target_alice = target_alice.send(alice)

#在正式环境上,可以通过其他方式上述数据的指针传过来。
#保存指针,至此,数据准备完成了,开始进行正式训练过程。
datasets = [(p_data_bob, p_target_bob), (p_data_alice, p_target_alice)]

#初始化一个线性分类器y = w_1*x_1+w_2*x_2+b
model = nn.Linear(2,1)

#sgd优化器
opt = optim.SGD(params=model.parameters(),lr=0.1)

#训练过程
def train():
    #10次迭代
    for iter in range(10):
         for data,target in datasets:
             #把上一轮训练好的模型,发给一个worker
             model.send(data.location)
             #梯度清零
             opt.zero_grad()
             #做预测
             pred = model(data)
             #计算loss
             loss = ((pred - target)**2).sum()
             #求导
             loss.backward()
             #更新模型参数
             opt.step()
             #更新模型
             model.get()
             #输出训练误差
             print(loss.get())

#运行
train()

 

你可能感兴趣的:([pysyft-002]联邦学习pysyft从入门到精通--三个节点训练一个线性分类器)