FATE系统主要支持表格数据作为其标准数据格式。然而,通过使用NN模块的数据集特性,可以在神经网络中使用非表格数据,例如图像、文本、混合数据或关系数据。NN模块中的数据集模块允许定制数据集,以用于更复杂的数据场景。本教程将介绍Homo NN模块中数据集功能的使用,并提供如何自定义数据集的指导。我们将使用MNIST手写识别任务作为示例来说明这些概念。
请从以下链接下载MNIST数据集,并将其放在项目示例/数据文件夹中:MNIST
这是MNIST数据集的简化版本,共有十个类别,根据标签分为0-9 10个文件夹。我们对数据集进行采样以减少样本数量。
MNIST数据集的来源是:http://yann.lecun.com/exdb/mnist/
在FATE-1.10版本中,FATE为数据集引入了一个新的基类,称为Dataset,它基于PyTorch的Dataset类。此类允许用户根据其特定需求创建自定义数据集。其用法与PyTorch的Dataset类类似,在使用FATE-NN进行数据读取和训练时,需要实现两个额外的接口:load()和get_sample_ids()。
开发继承自dataset类的新数据集类
实现__len__()和__getitem__()方法,它们与PyTorch的数据集用法一致。__len__()方法应返回数据集的长度,而__getitem_()方法则应返回指定索引处的相应数据
实现load()和get_sample_ids()方法
对于不熟悉PyTorch的数据集类的人,可以在PyTorch文档中找到更多信息:PyTorch数据集文档
所需的第一个附加接口是load()。此接口接收文件路径,并允许用户直接从本地文件系统读取数据。提交任务时,可以通过读取器组件指定数据路径。Homo NN将使用用户指定的Dataset类,利用load()接口从指定路径读取数据,并完成数据集的加载以进行训练。有关更多信息,请参阅/federatedml/nn/dataset/base.py中的源代码。
第二个附加接口是get_sample_ids()。此接口应返回一个样本ID列表,该列表可以是整数或字符串,并且长度应与数据集相同。实际上,当使用Homo NN时,您可以跳过实现这个接口,因为Homo NN组件将自动为样本生成ID。
为了更好地理解数据集的定制,我们在这里实现了一个简单的图像数据集来读取MNIST图像,然后在横向场景中完成联合图像分类任务。为了方便起见,我们使用save_to_rate的jupyter接口来更新代码以federatedml.nn.Dataset,名为MNIST_Dataset.py,当然,您可以手动将代码文件复制到目录中
from pipeline.component.nn import save_to_fate
这里我们实现了数据集,并使用save_to_date()保存它。
%%save_to_fate dataset mnist_dataset.py
import numpy as np
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
class MNISTDataset(Dataset):
def __init__(self, flatten_feature=False): # flatten feature or not
super(MNISTDataset, self).__init__()
self.image_folder = None
self.ids = None
self.flatten_feature = flatten_feature
def load(self, path): # read data from path, and set sample ids
# read using ImageFolder
self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))
# filename as the image id
ids = []
for image_name in self.image_folder.imgs:
ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))
self.ids = ids
return self
def get_sample_ids(self): # implement the get sample id interface, simply return ids
return self.ids
def __len__(self,): # return the length of the dataset
return len(self.image_folder)
def __getitem__(self, idx): # get item
ret = self.image_folder[idx]
if self.flatten_feature:
img = ret[0][0].flatten() # return flatten tensor 784-dim
return img, ret[1] # return tensor and label
else:
return ret
在我们实现数据集之后,我们可以在本地测试它:
from federatedml.nn.dataset.mnist_dataset import MNISTDataset
ds = MNISTDataset(flatten_feature=True)
# load MNIST data and check
ds.load('/mnt/hgfs/YOLOV5/mnist/') # 切换成自己下载上文中minist文件夹的地址
print(len(ds))
print(ds[0])
print(ds.get_sample_ids()[0])
在提交任务之前,可以在本地进行测试。正如我们在2.1 Homo NN 二进制分类任务中提到的,在Homo NN中,FATE默认使用fedavg_trainer。自定义数据集、模型和训练器可用于本地调试,以测试程序是否正确运行。请注意,在本地测试期间,将跳过所有联合过程,并且模型不会执行联合平均。
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
trainer = FedAVGTrainer(epochs=3, batch_size=256, shuffle=True, data_loader_worker=8, pin_memory=False) # set parameter
trainer.local_mode() # !! Be sure to enable local_mode to skip the federation process !!
import torch as t
from pipeline import fate_torch_hook
fate_torch_hook(t)
# our simple classification model:
model = t.nn.Sequential(
t.nn.Linear(784, 32),
t.nn.ReLU(),
t.nn.Linear(32, 10),
t.nn.Softmax(dim=1)
)
trainer.set_model(model) # set model
optimizer = t.optim.Adam(model.parameters(), lr=0.01) # optimizer
loss = t.nn.CrossEntropyLoss() # loss function
trainer.train(train_set=ds, optimizer=optimizer, loss=loss) # use dataset we just developed
在Trainer的train()函数中,将使用Pytorch DataLoader迭代数据集。程序可以正确运行!现在我们可以提交联合任务了。
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
t = fate_torch_hook(t)
这里,我们使用pipeline将路径绑定到名称和命名空间。然后,我们可以使用读取器组件将此路径传递到数据集的“加载”接口。培训师将在train()中获取此数据集,并使用Pytorch Dataloader对其进行迭代。请注意,在本教程中,我们使用的是独立版本,如果您使用的是集群版本,则需要将数据与每台计算机上的相应名称和命名空间绑定。
import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('../')
host = 10000
guest = 9999
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
arbiter=arbiter)
data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}
# 这里需要根据自己得版本作出调整,否则文件参数上传失败会报错
data_path_0 = fate_project_path + '/examples/data/mnist_train'
data_path_1 = fate_project_path + '/examples/data/mnist_train'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)
{'namespace': 'experiment', 'table_name': 'mnist_host'}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)
使用dataset_name指定数据集的模块名称,并在后面填写其参数,这些参数将传递给数据集的__init__接口。请注意,数据集参数需要是JSON可序列化的,否则pipeline无法解析它们。
from pipeline.component.nn import DatasetParam
dataset_param = DatasetParam(dataset_name='mnist_dataset', flatten_feature=True) # specify dataset, and its init parameters
from pipeline.component.homo_nn import TrainerParam # Interface
# our simple classification model:
model = t.nn.Sequential(
t.nn.Linear(784, 32),
t.nn.ReLU(),
t.nn.Linear(32, 10),
t.nn.Softmax(dim=1)
)
nn_component = HomoNN(name='nn_0',
model=model, # model
loss=t.nn.CrossEntropyLoss(), # loss
optimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizer
dataset=dataset_param, # dataset
trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),
torch_seed=100 # random seed
)
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))
pipeline.compile()
pipeline.fit()
pipeline.get_component('nn_0').get_output_data()
pipeline.get_component('nn_0').get_summary()
{'best_epoch': 1,
'loss_history': [3.58235876026547, 3.4448592824914055],
'metrics_summary': {'train': {'accuracy': [0.25668449197860965,
0.4950343773873186],
'precision': [0.3708616690797323, 0.5928620913124757],
'recall': [0.21817632850241547, 0.4855654369784805]}},
'need_stop': False}