当前有一项工作需要实现 SCAFFOLD 算法。该方法通过引入修正项(控制变量)c 来缓解客户端漂移(client drift)现象。
在参考 GitHub 上的相关框架后,复现了该算法。
算法分为三个模块:
optimizer: 重写优化器 SGD
clientscaffold:客户端操作
serverscaffold:服务端操作
optimizer部分代码:
import torch
from torch.optim import Optimizer
class SCAFFOLDOptimizer(Optimizer):
    """SGD-style optimizer that applies the SCAFFOLD control correction.

    Every step moves a parameter along ``grad + c - c_i`` scaled by the
    learning rate of its parameter group, where ``c`` is the server control
    and ``c_i`` is the client control for that parameter.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate stored as the ``lr`` default of every group.
        weight_decay: Stored as a group default. NOTE: ``step`` does not
            apply it; it is kept only for interface compatibility.
    """

    def __init__(self, params, lr, weight_decay):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SCAFFOLDOptimizer, self).__init__(params, defaults)

    def step(self, server_controls, client_controls, closure=None):
        """Perform one corrected update on every parameter that has a gradient.

        Args:
            server_controls: Sequence of tensors (``c``), one per parameter,
                ordered like the parameters across all groups.
            client_controls: Sequence of tensors (``c_i``), same ordering.
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The closure has to be called; storing the callable itself
            # would hand the caller a function instead of a loss value.
            loss = closure()

        # Flatten all groups into one stream so each parameter is paired
        # with its own control entry even when there are several groups.
        params_with_lr = (
            (p, group['lr'])
            for group in self.param_groups
            for p in group['params']
        )
        for (p, lr), c, ci in zip(params_with_lr, server_controls, client_controls):
            if p.grad is None:
                continue
            # The server control c corrects the local gradient direction.
            d_p = p.grad.data + c.data - ci.data
            p.data = p.data - d_p * lr
        return loss
serverscaffold:
from flcore.clients.clientscaffold import clientScaffold
from flcore.servers.serverbase import Server
from utils.data_utils import read_client_data
from threading import Thread
import torch
import random
class Scaffold(Server):
    """Server side of SCAFFOLD: holds the global model and server controls.

    Each round it pushes the global model and the server control ``c`` to
    every client, lets the selected clients train, and folds the uploaded
    model/control deltas back into the global state.
    """

    def __init__(self, device, dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps,
                 join_clients, num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate,
                 time_select, goal, time_threthold):
        """Create the clients and zero-initialise the server controls."""
        super().__init__(dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps,
                         join_clients, num_clients, times, eval_gap, client_drop_rate, train_slow_rate,
                         send_slow_rate, time_select, goal, time_threthold)
        # select slow clients
        self.set_slow_clients()
        self.global_model = model

        for i, train_slow, send_slow in zip(range(self.num_clients), self.train_slow_clients,
                                            self.send_slow_clients):
            train, test = read_client_data(dataset, i)
            client = clientScaffold(device, i, train_slow, send_slow, train, test, model, batch_size,
                                    learning_rate, local_steps)
            self.clients.append(client)

        print(f"\nJoin clients / total clients: {self.join_clients} / {self.num_clients}")

        # One zero control tensor per trainable parameter.
        # NOTE(review): other methods zip this list against the unfiltered
        # model.parameters(); this assumes every parameter is trainable.
        self.server_controls = [torch.zeros_like(p.data) for p in model.parameters() if p.requires_grad]

    def train(self):
        """Run ``global_rounds + 1`` communication rounds, then save results."""
        for i in range(self.global_rounds + 1):
            # Broadcast the global model together with the control c.
            self.send_parameters()

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate global model")
                self.evaluate()

            self.selected_clients = self.select_clients()
            for client in self.selected_clients:
                client.train()

            self.aggregate_parameters()

        print("\nBest global results.")
        self.print_(max(self.rs_test_acc), max(
            self.rs_train_acc), min(self.rs_train_loss))

        self.save_results()
        self.save_global_model()

    def send_parameters(self):
        """Copy the global model and the server controls to every client."""
        assert (len(self.clients) > 0)

        for client in self.clients:
            client.set_parameters(self.global_model)
            for control, new_control in zip(client.server_controls, self.server_controls):
                # Clone so the client never shares storage with the server's
                # control tensors (same convention as client.set_parameters).
                control.data = new_control.data.clone()

    def aggregate_parameters(self):
        """Sample the non-dropped clients and merge their uploaded deltas.

        Each active client is weighted by its share of the active clients'
        total training samples.
        """
        assert (len(self.selected_clients) > 0)

        # Never ask random.sample for more clients than were selected.
        num_active = min(len(self.selected_clients),
                         int((1 - self.client_drop_rate) * self.join_clients))
        active_clients = random.sample(self.selected_clients, num_active)

        active_train_samples = sum(client.train_samples for client in active_clients)

        self.uploaded_weights = [client.train_samples / active_train_samples
                                 for client in active_clients]

        for user, w in zip(active_clients, self.uploaded_weights):
            self.add_parameters(user, active_train_samples, w)

    def add_parameters(self, user, total_samples, w):
        """Add one client's weighted model and control deltas to the global state.

        Args:
            user: Client exposing ``delta_model`` and ``delta_controls``.
            total_samples: Total samples of the active clients (unused here;
                kept for interface compatibility).
            w: Aggregation weight of this client.
        """
        for param, control, del_control, del_model in zip(self.global_model.parameters(), self.server_controls,
                                                          user.delta_controls, user.delta_model):
            # Because the data is non-IID, the client's sample share w is
            # used instead of dividing by the number of clients.
            # NOTE(review): this differs from the 1/N scaling of the control
            # update in the SCAFFOLD paper - confirm this is intended.
            param.data = param.data + del_model.data * w
            control.data = control.data + del_control.data * w
clientscaffold:
import torch
import torch.nn as nn
from flcore.clients.clientbase import Client
import numpy as np
import time
import copy
from flcore.optimizers.fedoptimizer import *
import math
class clientScaffold(Client):
    """Client side of SCAFFOLD: local training with control correction.

    Keeps the client control ``c_i``, a copy of the server control ``c``, and
    the model/control deltas that the server reads after each round.
    """

    def __init__(self, device, numeric_id, train_slow, send_slow, train_data, test_data, model, batch_size,
                 learning_rate, local_steps):
        """Set up the loss, the SCAFFOLD optimizer and zero-initialised state."""
        super().__init__(device, numeric_id, train_slow, send_slow, train_data, test_data, model, batch_size,
                         learning_rate, local_steps)

        self.loss = nn.CrossEntropyLoss()
        L = 0  # Regularization term, left at its default value
        # The model is trained with the custom SCAFFOLD optimizer.
        self.optimizer = SCAFFOLDOptimizer(self.model.parameters(), lr=self.learning_rate, weight_decay=L)

        # NOTE(review): these lists only cover trainable parameters but are
        # later zipped with the unfiltered model.parameters(); this assumes
        # every parameter has requires_grad=True.
        self.controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        self.server_controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        self.delta_controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        # Model update state
        self.delta_model = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        self.server_model = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        self.local_model = copy.deepcopy(list(self.model.parameters()))

    def set_grads(self, new_grads):
        """Overwrite the ``.data`` of the model parameters with ``new_grads``.

        Args:
            new_grads: Either an ``nn.Parameter`` (iterated element-wise
                alongside the model parameters) or a list indexed in
                parameter order. Other types are ignored.
        """
        if isinstance(new_grads, nn.Parameter):
            for model_grad, new_grad in zip(self.model.parameters(), new_grads):
                model_grad.data = new_grad.data
        elif isinstance(new_grads, list):
            for idx, model_grad in enumerate(self.model.parameters()):
                model_grad.data = new_grads[idx]

    def train(self):
        """Run the local steps, then refresh the model and control deltas."""
        start_time = time.time()

        # self.model.to(self.device)
        self.model.train()

        max_local_steps = self.local_steps
        if self.train_slow:
            # Keep the upper bound at least 2 so randint(1, high) stays valid
            # for small local_steps values.
            max_local_steps = np.random.randint(1, max(2, max_local_steps // 2))

        for step in range(max_local_steps):
            if self.train_slow:
                time.sleep(0.1 * np.abs(np.random.rand()))
            x, y = self.get_next_train_batch()
            self.optimizer.zero_grad()
            output = self.model(x)
            loss = self.loss(output, y)
            loss.backward()
            self.optimizer.step(self.server_controls, self.controls)

        # Difference between the trained local model and the server model.
        for local, server, delta in zip(self.model.parameters(), self.server_model, self.delta_model):
            delta.data = local.data.clone() - server.data.clone()

        # New client controls via the paper's second option:
        #   c_i+ = c_i - c + (x - y_i) / (K * lr)
        # K is the number of local steps actually run above. The first
        # option (gradient at the server model) is not implemented.
        new_controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        # With zero steps the model delta is zero, so any positive K works.
        a = 1 / (max(1, max_local_steps) * self.learning_rate)
        for server_control, control, new_control, delta in zip(self.server_controls, self.controls, new_controls,
                                                               self.delta_model):
            new_control.data = control.data - server_control.data - delta.data * a

        # Control differences to upload, then adopt the new controls.
        for control, new_control, delta in zip(self.controls, new_controls, self.delta_controls):
            delta.data = new_control.data - control.data
            control.data = new_control.data

        self.train_time_cost['num_rounds'] += 1
        self.train_time_cost['total_cost'] += time.time() - start_time

    def set_parameters(self, server_model):
        """Copy the server model into the local, reference and training copies.

        Args:
            server_model: Model whose parameters are cloned into
                ``self.model``, ``self.local_model`` and ``self.server_model``.
        """
        for old_param, new_param, local_param, server_param in zip(self.model.parameters(),
                                                                   server_model.parameters(),
                                                                   self.local_model, self.server_model):
            old_param.data = new_param.data.clone()
            local_param.data = new_param.data.clone()
            server_param.data = new_param.data.clone()