"""
@Time : 2023/10/16 0016 15:17
@Auth : yeqc
"""
'''基于信任的联邦线性回归案例'''
import torch
import syft as sy
from torch import nn
# TODO:目前是CPU模式,等以后会用到GPU 修改成GPU模式
def train(num_epochs):
for iter in range(num_epochs):
for data, target in datasets:
# 把上一轮训练好的模型发给工作节点
model.send(data.location)
# 梯度清零
opt.zero_grad()
# 计算预测结果
pred = model(data)
# 计算loss
loss = ((pred - target) ** 2).sum()
# 求导
loss.backward()
# 更新模型参数
opt.step()
# 更新模型
model.get()
print(f'location = {data.location},loss = {loss.get()}')
hook = sy.TorchHook(torch)
# 两个数据拥有者Bob和Alice
bob = sy.VirtualWorker(hook, id="Bob")
alice = sy.VirtualWorker(hook, id="Alice")
data = torch.tensor([[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.float32, requires_grad=True)
target = torch.tensor([[0], [0], [1], [1]], dtype=torch.float32, requires_grad=True)
# 给数据拥有者Bob的数据
data_bob = data[0:2]
target_bob = target[0:2]
# 给数据拥有者Alice的数据
data_alice = data[2:]
target_alice = target[2:]
# 将数据发送给Bob和Alice
p_data_bob = data_bob.send(bob)
p_target_bob = target_bob.send(bob)
p_data_alice = data_alice.send(alice)
p_target_alice = target_alice.send(alice)
print("bob:", bob._objects)
print("alice:", alice._objects)
# 保存张量指针
datasets = [(p_data_bob, p_target_bob), (p_data_alice, p_target_alice)]
# 初始化线性回归模型,y=w1x1+w2x2+b
model = nn.Linear(3, 1)
# SGD优化器
opt = torch.optim.SGD(params=model.parameters(), lr=0.1)
if __name__ == "__main__":
num_epochs = 10
train(num_epochs)
print(model.state_dict())
运行结果: