FedAVGTrainer
import torch
import torch as t
import tqdm
import numpy as np
from torch.utils.data import DataLoader
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorServer as SecureAggServer
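# SecureAggClient / SecureAggServer are the client- and server-side halves of
# FATE's secure aggregation (random-mask) protocol referenced by the
# secure_aggregate option below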
from federatedml.nn.dataset.base import Dataset
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.util import LOGGER, consts
from federatedml.optim.convergence import converge_func_factory
class FedAVGTrainer(TrainerBase):
"""
Parameters
----------
    epochs: int > 0, number of epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default True, whether to use secure aggregation.
        If enabled, a random-number mask is added to each local model before
        upload; these masks eventually cancel out to 0 in the aggregate.
    weighted_aggregation: bool, whether to weight each local model during
        aggregation. If True, following the original FedAVG paper, the weight
        of a client is n_local / n_global, where n_local is the client's local
        sample count and n_global is the total sample count across all
        clients. If False, the models are simply averaged.
    early_stop: None, 'diff' or 'abs'. If None, early stopping is disabled;
        if 'diff', the loss difference between two consecutive epochs is the
        stopping condition, and training stops when the difference < tol;
        if 'abs', training stops when loss < tol.
    tol: float, tolerance value for early stopping
    aggregate_every_n_epoch: None or int. If None, aggregate the model at the
        end of every epoch; if int, aggregate every n epochs.
    cuda: bool, whether to use CUDA
    pin_memory: bool, for the pytorch DataLoader
    shuffle: bool, for the pytorch DataLoader
    data_loader_worker: int, number of workers the pytorch DataLoader uses
        when loading data
    validation_freqs: None or int. If int, validate the model and send the
        validation results to fate-board every n epochs.
        Binary classification tasks use the metrics 'auc', 'ks', 'gain',
        'lift' and 'precision';
        multi-class classification tasks use 'precision', 'recall' and
        'accuracy';
        regression tasks use 'mse', 'mae', 'rmse', 'explained_variance' and
        'r2_score'.
    checkpoint_save_freqs: None or int, save the model every n epochs;
        if None, no checkpoints are saved.
    task_type: str, one of 'auto', 'binary', 'multi', 'regression'.
        This option decides the return format of the trainer and the
        evaluation type used when running validation.
        If 'auto', the task type is inferred from the labels and prediction
        results.
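
    Examples
    --------
    A minimal sketch of standalone use, assuming ``model`` is a prepared
    torch.nn.Module and that ``set_model`` is the helper inherited from
    TrainerBase; both names are illustrative, and in a FATE job this trainer
    is normally configured through the HomoNN component rather than
    constructed by hand:

    >>> trainer = FedAVGTrainer(epochs=5, batch_size=256,
    ...                         weighted_aggregation=True, task_type='auto')
    >>> trainer.set_model(model)  # attach the local torch model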
"""
def __init__(self, epochs=10, batch_size=512, # training parameter
early_stop=None, tol=0.0001, # early stop parameters
secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None, # federation
cuda=False, pin_memory=True, shuffle=True, data_loader_worker=0, # GPU & dataloader
validation_freqs=None, # validation configuration
checkpoint_save_freqs=None, # checkpoint configuration
task_type='auto'
):
super(FedAVGTrainer, self).__init__()
# training parameters
self.epochs = epochs
self.tol = tol
self.validation_freq = validation_freqs
self.save_freq = checkpoint_save_freqs
self.task_type = task_type
task_type_allow = [
consts.BINARY,
consts.REGRESSION,
consts.MULTY,
'auto']
        assert self.task_type in task_type_allow, 'task type must be in {}'.format(
            task_type_allow)
# aggregation param
self.secure_aggregate = secure_aggregate
self.weighted_aggregation = weighted_aggregation
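        # when weighted_aggregation is enabled, each client's model is scaled
        # by n_local / n_global during aggregation (see the docstring above);
        # otherwise the server takes a plain unweighted average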
self.aggregate_every_n_epoch = aggregate_every_n_epoch
# GPU
self.cuda = cuda
        if self.cuda and not torch.cuda.is_available():
            raise ValueError('CUDA is not available on this machine')
# data loader
self.batch_size = batch_size
self.pin_memory = pin_memory
self.shuffle = shuffle
self.data_loader_worker = data_loader_worker
self.early_stop = early_stop
early_stop_type = ['diff', 'abs']
if early_stop is not None:
            assert early_stop in early_stop_type, 'early stop type must be in {}, but got {}' \
                .format(early_stop_type, early_stop)
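        # early_stop and tol are later turned into a 'diff' or 'abs'
        # convergence checker, presumably via the converge_func_factory
        # imported above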
# communicate suffix
self.comm_suffix = 'fedavg'
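
To make the secure_aggregate option concrete, the toy snippet below shows why additive random masks cancel out during aggregation. It is a sketch only, not the SecureAggregatorClient protocol: real secure aggregation negotiates pairwise mask seeds via key exchange, whereas this example hands both clients a shared seed directly, and masked_model is a hypothetical helper.

import numpy as np

def masked_model(weights, shared_seed, sign):
    # both peers derive the same mask from the shared seed; one adds it
    # (sign=+1) and the other subtracts it (sign=-1), so the masks cancel
    # in the server-side sum
    rng = np.random.default_rng(shared_seed)
    return weights + sign * rng.normal(size=weights.shape)

w_a = np.array([0.2, -1.0, 3.5])   # client A's local model weights
w_b = np.array([1.0, 0.5, -2.5])   # client B's local model weights

sent_a = masked_model(w_a, shared_seed=42, sign=+1)
sent_b = masked_model(w_b, shared_seed=42, sign=-1)

# the server never sees w_a or w_b, yet their sum is recovered exactly
assert np.allclose(sent_a + sent_b, w_a + w_b)
print((sent_a + sent_b) / 2)  # unweighted FedAVG average of the two models

The same cancellation argument extends to more than two parties by giving each pair of clients its own mask, which is the usual construction in secure aggregation schemes.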