PySyft中model.get()方法解释

最近在看联邦学习相关内容,上手了PySyft框架。该框架相当于是在Pytorch深度学习框架上扩展的第三方API,对我经常使用Pytorch深度学习框架的人来说再合适不过。然后在网上找各种帖子学习了一点基础的东西,其中有一篇比较经典的帖子是:Pysyft实战4:使用pysyft在MNIST数据集上训练CNN网路。当然,我没完全复现这篇帖子上的代码,我的代码精简了,如下所示:

import torch
import syft as sy
from torch import nn
from torch import optim


# 定义训练函数
def train():
    for iter in range(2):
        # 遍历每个工作机的数据集
        for data, target in datasets:
            # 将模型发送给对应的虚拟机
            print("data: ", data)
            model.send(data.location)
            # 消除之前的梯度
            opt.zero_grad()
            # 预测
            print("data.location: ", data.location)
            print("data.location.objects: ", data.location._objects)
            pred = model(data)
            # 计算损失
            loss = ((pred - target) ** 2).sum()
            # 回传损失
            loss.backward()
            # 更新参数
            opt.step()
            # 获取模型
            model.get()
            print("model data._objects: ", data.location._objects)


if __name__ == "__main__":
    hook = sy.TorchHook(torch)

    # 分别创建虚拟工作机James和Ken
    James = sy.VirtualWorker(hook, id='James')
    Ken = sy.VirtualWorker(hook, id='Ken')

    # 创建数据
    data = torch.tensor([[0, 1], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
    target = torch.tensor([[0], [0], [1], [1.]], requires_grad=True)

    # 创建模型  形如全连接层:y=a*x+b  (in_features, out_features)
    model = nn.Linear(2, 1)

    # 将训练数据发送给主工作机,数据分为两部分,分别发给James和Ken
    # data_James, target_James 均为 PointTensor 意思就是不具备本地data target数据的控制权
    data_James, target_James = data[:2].send(James), target[:2].send(James)
    data_Ken, target_Ken = data[2:].send(Ken), target[2:].send(Ken)


    # 存储张量指针
    datasets = [(data_James, target_James), (data_Ken, target_Ken)]

    # 定义优化器
    opt = optim.SGD(params=model.parameters(), lr=0.1)
    train()

如上代码是正确可运行的,但是其中有两步当时一直让我看不太懂

就是如下两步,这两步不是连续的哈,先备注下

model.send(data.location)
...
# 获取模型
model.get()

为什么觉得这块有蹊跷呢?是因为我觉得可能不需要model.get()这一步,模型都已经训练完了。我知道作者的意思,无非就是取回在work工作机端训练好的模型参数嘛(w, b),但是这一步实在是太简略了,导致我完全无法理解其取回过程。

本着追根溯源的态度,我打印出了一些关键节点的信息,并且注释掉了model.get()这一步。发现立马报错,首先展示出报错信息:

data.location.objects:  {4641078952: tensor([[0., 1.],
        [0., 1.]], requires_grad=True), 8244462310: tensor([[0.],
        [0.]], requires_grad=True), 91856073765: Parameter containing:
            tensor([[-0.3203,  0.5368]], requires_grad=True), 69906511568: 
            Parameter containing:tensor([-0.4525], requires_grad=True)}

model data._objects:  {4641078952: tensor([[0., 1.],
        [0., 1.]], requires_grad=True), 8244462310: tensor([[0.],
        [0.]], requires_grad=True), 91856073765: Parameter containing:
           tensor([[-0.3203,  0.5031]], requires_grad=True), 69906511568: 
        Parameter containing:tensor([-0.4862], requires_grad=True), 96213712738:         tensor([[0.0843], [0.0843]], grad_fn=), 86949987024: tensor(0.0142, grad_fn=)}

data.location.objects:  {60123234592: tensor([[1., 0.],
        [1., 1.]], requires_grad=True), 97986492588: tensor([[1.],
        [1.]], requires_grad=True), 87280674109: Parameter containing:
(Wrapper)>[PointerTensor | Ken:87280674109 -> James:91856073765], 33491287113: Parameter containing: (Wrapper)>[PointerTensor | Ken:33491287113 -> James:69906511568]} response = handle_func_command(command)
  
...
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

可以看到当我们运行完一轮James和Ken工作端的数据后,立马就报错了。报错原因也看到了,就是当模型在Ken工作端运行时,传入的参数应该就不是从me(服务器端)传给Ken工作端的,因为刺客模型参数的实际数据还是保存在James工作端的,所以此刻Ken工作端的参数数据已经为PointerTensor格式了,如下所示:

87280674109: 
Parameter containing: (Wrapper)>[PointerTensor | Ken:87280674109 -> James:91856073765], 33491287113: 
Parameter containing:(Wrapper)>[PointerTensor | Ken:33491287113 -> James:69906511568]

我的理解是:如果没有model.get()这一步操作,那么me(服务器端)将没有模型参数(w, b),如果还要继续model.send(),那么只能是从其他的work工作端(James)调用模型参数了,这时候已经主次混乱了(服务器和工作端的混乱),不符合标准。

以上是我的一些揣测,说实话我也没有准确的剖析。权当是抛砖引玉,如果大佬有更精确的分析,还望不吝赐教!!!

你可能感兴趣的:(pytorch,Pysyft,联邦学习,python)