Pysyft使用指南

文章目录

  • 项目要求
  • Pysyft
    • 版本问题
    • 基本流程
    • 遇到的问题
      • Dataset问题
      • Tensor传递问题
      • 数据冲突
      • 数据取回
      • Optimizer选择
        • Optimizer补充
    • 运行体验

项目要求

由于项目需要,最近研究了一下联邦学习框架。项目要求把新闻推荐算法部署到联邦学习框架上。我们选用了Pysyft联邦学习框架,以及微软MIND数据集及算法。

以下为利用到的开源数据和代码等。

Pysyft框架:https://github.com/OpenMined/PySyft
Pysyft0.2.x版本:https://github.com/OpenMined/PySyft/tree/syft_0.2.x
MIND数据集:https://msnews.github.io/
MIND推荐算法:https://github.com/microsoft/recommenders
MIND推荐算法torch版本:https://github.com/yusanshi/NewsRecommendation

Pysyft

版本问题

在做这个项目过程中,syft刚好从0.2版本更新到了0.3版本。两个版本之间有较大差异,0.3不兼容0.2的代码。并且0.3的文档开发者还没有完成,只有少量的examples。

基于此,我们选用了0.2.x版本的Pysyft完成我们的项目。该版本Pysyft有详细的教程。

Pysyft 0.2.x 采用python>=3.6, pytorch=1.4.0

基本流程

基本流程可以参考CNN做MNIST分类。

  1. 建立多个虚拟的workers
# NEW: import the Pysyft library
import syft as sy  
# NEW: hook PyTorch ie add extra functionalities to support Federated Learning
hook = sy.TorchHook(torch)
# NEW: define remote worker bob  
bob = sy.VirtualWorker(hook, id="bob")  
# NEW: and alice
alice = sy.VirtualWorker(hook, id="alice")  
  1. 导入数据集,并分发到各个虚拟的worker上。操作非常简单,将torch.Datasetsyft.FederatedDataLoader进行操作即可。
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)
  1. 将模型分发到各个虚拟的worker上,进行训练。其中,模型的定义方法和torch相同。
:

def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
        model.send(data.location) # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))

遇到的问题

Dataset问题

Pysyft已经完成了接收数据和分发数据的代码。但是它只支持dataset.__getitem__()返回两个值,datatarget,并且datatarget需要是tensor类型(也有可能可以是numpy等,未尝试)。

教程中给出的CNN只有一个输入,如上段所述,查看syft源码发现它只支持单输入。而新闻推荐模型有多个输入。原代码中给出的输入类型为字典。

我们对dataset类的__getitem__方法进行改写,将原本的多个输入拼成一个大的tensor输入进行返回。然后在模型中再将这个大tensor进行切片处理。此过程中需要注意tensor维度,以及在DataLoader时增加的batch_size维度。注意对模型进行相应的修改。

Tensor传递问题

需要注意的是syft传递到client端的tensor只是一个PointerTensor,而不是完整了tensor数据。

如果模型中有用到tensor.size()来确定维度的,一定要将其改掉。因为传到client的tensor的size一定是0。

另外需要注意的是,虽然传入的只是PointerTensor,我们在模型中仍然可以对其进行切片、加减乘除等各种运算。

这里也记录一个非常奇怪的错误:

import torch
a = torch.rand(2,3)
b = torch.rand(2,3)
c = a/b

在正常的torch程序中可以执行以上代码,输出对应元素相除后的tensor。但是如果除法出现在模型中,在syft框架调用训练的时候就会报RecursionError: maximum recursion depth exceeded in comparison解决办法:不要使用/作为除法,而采用torch.div()

数据冲突

使用pysyft训练的过程中,所有训练模型中用到的数据一定都要分发到client端。如果训练数据在client端,而标签在server端,会报错。

数据取回

在client端训练完成之后,需要将一些数据取回server端。采用tensor.get()方法。一定要注意,必须是后面在client端不会再使用的数据才能通过这种方法取回。因为在get()了之后会改变其本身的类型,无法进行别的操作,也无法进行二次get()。具体可以参考教程进行实验验证。

Optimizer选择

注意:Syft0.2.x只支持不含momentum的optimizer。官方解释为# TODO momentum is not supported at the moment,详见Tutorial 6。

选用Adam等优化器会在optimizer.step()处报add_() takes 1 positional argument but 2 were given错误。

只需要将optimizer改为SGD即可。

Optimizer补充

在Syft的issue里发现了关于Adam无法使用的解决办法,详见Issue #2070。

简单来讲就是对每个machine单独训练一个model,在一些step之后对模型进行聚合即可。代码可以参考Tutorial 4。

注意: Issue #2070中有一条声称可以通过修改Adam优化器源码来解决Syft无法使用Optimizer问题的,该解决办法无效。它只能解决add_() takes 1 positional argument but 2 were given错误,但是仍会报add_() tensor not on the same machine错误。

运行体验

占用显存较少,速度较慢。

你可能感兴趣的:(代码学习,深度学习)