[pysyft-003] Federated Learning with PySyft, from Beginner to Expert -- Four Nodes Train Two Linear Classifiers and Take a Simple Average

import torch
from torch import nn, optim
import syft as sy

# Hook PyTorch so tensors and models gain the .send()/.get()/.move() methods
hook = sy.TorchHook(torch)


'''
Part 4: Federated Learning with Model Averaging
http://localhost:8888/notebooks/git-home/github/PySyft/examples/tutorials/Part%2004%20-%20Federated%20Learning%20via%20Trusted%20Aggregator.ipynb
'''

"""
本例演示:
A节点运行脚本。B、C两个节点分别有样本集,各自训练一个模型。D节点把B和C节点的模型进行简单平均。
"""



# Create the virtual workers (nodes B, C, and the trusted aggregator D)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")


# This smaller dataset is not well suited to model averaging: the 2/2 split below
# gives bob only samples with target 0 and alice only samples with target 1, so
# each worker trains on a single class and averaging the two models works poorly.
## Dataset
#data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
#target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)

## Split into sub-datasets and send them to the workers
#bobs_data = data[0:2].send(bob)
#bobs_target = target[0:2].send(bob)
#alices_data = data[2:].send(alice)
#alices_target = target[2:].send(alice)



# Dataset
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.], [0.1,0.1],[0.1,1.1],[1.1,0.1],[1.1,1.1],], requires_grad=True)
target = torch.tensor([[0],[0],[1],[1.], [0],[0],[1],[1.]], requires_grad=True)


# Split into sub-datasets and send them to the workers
bobs_data = data[0:4].send(bob)
bobs_target = target[0:4].send(bob)

alices_data = data[4:].send(alice)
alices_target = target[4:].send(alice)
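
# Optional sanity check (a sketch, not part of the original post): after .send(),
# the local variables are only pointers; the actual tensors live in the remote
# VirtualWorker's object store. In PySyft 0.2.x this can be inspected roughly as
# follows (the exact printed representation may differ between versions):
print(bobs_data)        # a wrapper around a pointer tensor referencing worker "bob"
print(bob._objects)     # the tensors actually held by the virtual worker "bob"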

# A single linear model (2 inputs -> 1 output) that serves as the global model
model = nn.Linear(2,1)

# Copy the model and send one copy to each worker
bobs_model = model.copy().send(bob)
alices_model = model.copy().send(alice)

# One SGD optimizer per remote model
bobs_opt = optim.SGD(params=bobs_model.parameters(),lr=0.1)
alices_opt = optim.SGD(params=alices_model.parameters(),lr=0.1)


# 10 local training iterations on each worker
for i in range(10):
    # Train bob's model on bob's data
    bobs_opt.zero_grad()
    bobs_pred = bobs_model(bobs_data)
    bobs_loss = ((bobs_pred - bobs_target)**2).sum()
    bobs_loss.backward()
    bobs_opt.step()
    bobs_loss = bobs_loss.get().data  # retrieve the loss value from bob for logging

    # Train alice's model on alice's data
    alices_opt.zero_grad()
    alices_pred = alices_model(alices_data)
    alices_loss = ((alices_pred - alices_target)**2).sum()
    alices_loss.backward()
    alices_opt.step()
    alices_loss = alices_loss.get().data  # retrieve the loss value from alice for logging

    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))

    

# Move both models to secure_worker and take a simple average there
alices_model.move(secure_worker)
bobs_model.move(secure_worker)
with torch.no_grad():
    model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
    model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())
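
# A more general averaging sketch (not from the original post): if the model had
# more layers, averaging every parameter avoids hard-coding weight/bias by name.
# Assumes both remote models share the same architecture and have already been
# moved to secure_worker, as above.
with torch.no_grad():
    for p, a_p, b_p in zip(model.parameters(),
                           alices_model.parameters(),
                           bobs_model.parameters()):
        p.set_(((a_p.data + b_p.data) / 2).get())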


print('\n alices_model:')
print((alices_model.weight.data + 0).get())  # "+ 0" makes a remote copy, so .get() does not pull back the parameter itself
print((alices_model.bias.data + 0).get())


print('\n bobs_model:')
print((bobs_model.weight.data + 0).get())
print((bobs_model.bias.data + 0).get())


print('\nmodel:')
print(model.weight.data)
print(model.bias.data)
    


# Evaluate the averaged model on the full dataset
preds = model(data)
loss = ((preds - target)**2).sum()

print('\naveraged model result:')
print(preds)
print(target)
print(loss.data)
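

'''
Extension sketch (an assumption based on the flow of the original tutorial, not
part of this post's code): in practice the send -> local-train -> average cycle
is repeated for several rounds, re-sending the freshly averaged global model to
the workers at the start of each round. This reuses the data splits, model, and
workers defined above.
'''
for round_iter in range(10):
    # Send a fresh copy of the current global model to each worker
    bobs_model = model.copy().send(bob)
    alices_model = model.copy().send(alice)
    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

    # Local training on each worker
    for i in range(10):
        bobs_opt.zero_grad()
        bobs_loss = ((bobs_model(bobs_data) - bobs_target)**2).sum()
        bobs_loss.backward()
        bobs_opt.step()

        alices_opt.zero_grad()
        alices_loss = ((alices_model(alices_data) - alices_target)**2).sum()
        alices_loss.backward()
        alices_opt.step()

    # Aggregate on the trusted worker and update the global model
    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)
    with torch.no_grad():
        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())

    print("Round " + str(round_iter) + " global loss: " +
          str(((model(data) - target)**2).sum().data))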

 
