FATE —— 二.4.1 在CIFAR-10上训练联邦ResNet

在这个示例中,我们将展示如何使用torchvision模型来执行联邦分类任务。

数据集:CIFAR-10

您可以通过以下链接下载CIFAR-10数据集:CIFAR-10

CIFAR-10来源于:https://www.cs.toronto.edu/~kriz/cifar.html

为了便于演示,各参与方(客户端)将使用相同的数据集。

本地测试

首先,我们在本地测试模型和数据集。如果运行正常,就可以提交联邦任务。

from pipeline.component.nn import save_to_fate
%%save_to_fate model resnet.py

# model
import torch as t
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights

class Resnet(nn.Module):
    """ResNet-18 backbone followed by a linear classification head.

    ``resnet18()`` is built without a ``weights`` argument, i.e. with random
    initialisation, and keeps torchvision's default 1000-way output layer.
    A linear layer then maps those 1000 features to ``num_classes`` scores.

    Args:
        num_classes: Number of target classes. Defaults to 10 (CIFAR-10), so
            ``Resnet()`` builds exactly the same network as before.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.resnet = resnet18()
        # resnet18() outputs 1000 features (torchvision's default fc layer).
        self.classifier = t.nn.Linear(1000, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return raw logits in training mode, softmax probabilities in eval mode.

        Args:
            x: Image batch, presumably (batch, 3, H, W) such as the 3x32x32
               CIFAR-10 samples - TODO confirm for other datasets.
        """
        logits = self.classifier(self.resnet(x))
        # CrossEntropyLoss expects unnormalised logits, so the softmax is
        # only applied outside of training.
        if self.training:
            return logits
        return self.softmax(logits)


if __name__ == "__main__":
    # This cell is also saved to resnet.py by %%save_to_fate. Guarding the
    # sample instantiation keeps a plain import of that module (presumably
    # how FATE loads it via CustModel) free of the side effect of building
    # and printing a model. In the notebook, __name__ is "__main__", so
    # `model` is still defined for the local test.
    model = Resnet()
    print(model)
FATE —— 二.4.1 联邦Rensnet关于CIFAR-10的训练_第1张图片
# read dataset
from federatedml.nn.dataset.image import ImageDataset

# Local sanity check: load the CIFAR-10 training images with FATE's ImageDataset.
# NOTE(review): hard-coded path; the federated job later binds the same
# directory as `data_path`, so keep the two in sync.
ds = ImageDataset()
ds.load('/mnt/hgfs/cifar-10/cifar10/train/')
# Shape of the first sample's image tensor (3x32x32 for CIFAR-10). As a bare
# expression it is only displayed when run as a notebook cell.
ds[0][0].shape

# Out: torch.Size([3, 32, 32])

# local test
import torch as t
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer

# Build a FedAVG trainer for a short one-epoch local run of the Resnet model.
trainer = FedAVGTrainer(epochs=1, batch_size=1024, data_loader_worker=4)
trainer.set_model(model)

optimizer = t.optim.Adam(model.parameters(), lr=0.001)
# Resnet.forward returns raw logits while training, as CrossEntropyLoss expects.
loss = t.nn.CrossEntropyLoss()

trainer.local_mode() # set local mode: presumably trains without contacting other parties
# Second argument None: presumably no validation set for this quick check.
trainer.train(ds, None, optimizer, loss)

提交联邦任务

import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

# Patch torch so the FATE-specific calls used below work, e.g.
# t.nn.CustModel and t.optim.Adam without model parameters.
fate_torch_hook(t)

import os
# fate_project_path = os.path.abspath('../../../../')

# Party ids for this demo; the host party also plays the arbiter role.
guest = 10000
host = 9999
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(
    guest=guest, host=host, arbiter=host)

# Table name/namespace that the readers for both parties will read.
data_0 = {"name": "cifar10", "namespace": "experiment"}
# Set this to your own local path of the CIFAR-10 training images.
data_path = '/mnt/hgfs/cifar-10/cifar10/train/'
# Bind the image directory to the table; a single call is sufficient.
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)

# Out: {'namespace': 'experiment', 'table_name': 'cifar10'}

def _make_reader(name):
    """Create a Reader component that reads table ``data_0`` for guest and host."""
    reader = Reader(name=name)
    for role, party_id in (('guest', guest), ('host', host)):
        reader.get_party_instance(role=role, party_id=party_id).component_param(table=data_0)
    return reader


# reader_0 feeds the training data and reader_1 the validation data; both read
# the same table, since every party uses the same dataset in this demo.
reader_0 = _make_reader("reader_0")
reader_1 = _make_reader("reader_1")
from pipeline.component.homo_nn import DatasetParam, TrainerParam

# Federated model: the Resnet class saved to resnet.py by %%save_to_fate,
# referenced by module/class name (presumably so each party can import it).
# NOTE: this rebinds `model`, replacing the local Resnet instance used above.
model = t.nn.Sequential(
    t.nn.CustModel(module_name='resnet', class_name='Resnet')
)

nn_component = HomoNN(name='nn_0',
                      model=model, 
                      loss=t.nn.CrossEntropyLoss(),
                      optimizer = t.optim.Adam(lr=0.001, weight_decay=0.001),
                      dataset=DatasetParam(dataset_name='image'),  # 'image' dataset, presumably the same ImageDataset used in the local test
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=10, batch_size=1024, data_loader_worker=8),
                      torch_seed=100
                      )
# Wire the DAG: readers -> HomoNN (train + validate) -> multi-class evaluation.
pipeline.add_component(reader_0)
pipeline.add_component(reader_1)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data, validate_data=reader_1.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))
pipeline.compile()
pipeline.fit() # submit pipeline here
FATE —— 二.4.1 联邦Rensnet关于CIFAR-10的训练_第2张图片
pipeline.get_component('nn_0').get_output_data()  # get result: output data of component nn_0 (shown as notebook cell output)
FATE —— 二.4.1 联邦Rensnet关于CIFAR-10的训练_第3张图片
保存结果并查看损失过程
import pandas as pd
# Fetch nn_0's output data and persist it as a CSV file.
df = pipeline.get_component('nn_0').get_output_data()  # get result
output_csv = '联邦Rensnet关于CIFAR-10的训练.csv'
df.to_csv(output_csv)

# Summary of component nn_0. The section heading says this is where the loss
# history is inspected; confirm the exact contents against your FATE version.
pipeline.get_component('nn_0').get_summary()
FATE —— 二.4.1 联邦Rensnet关于CIFAR-10的训练_第4张图片

你可能感兴趣的:(联邦学习,深度学习,人工智能,python,算法)