Communication-Efficient Learning of Deep Networks from Decentralized Data
Paper: [1602.05629] Communication-Efficient Learning of Deep Networks from Decentralized Data (arxiv.org)
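1. Communication links have unstable rates and may be unreliable.
2. The aggregation server has limited capacity, so the number of clients that can communicate with the server at the same time is limited.
At every step of FL, consider:
1. Reducing the number of participating clients
2. Reducing the communication bandwidth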
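Increase the computation done on each client and limit how often it communicates (each client runs several local gradient descent iterations before uploading its update).
Randomly sample m clients and average the updates from these m clients to form the global update; clients that were not sampled have their models replaced by the current global model.
Advantage: compared with FedSGD, the communication cost is much lower for the same model quality (the paper's abstract reports 10-100x fewer communication rounds).
Disadvantage: the final model is biased. It differs from the model that would be obtained by deterministically aggregating every client.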
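The sampling step can be sketched as follows. This is a minimal illustration, not the paper's reference code: `sample_clients` and its parameters are hypothetical names, and the rule `m = max(int(frac * num_users), 1)` is taken from the training loop quoted later in these notes.

```python
import random


def sample_clients(num_users, frac, seed=None):
    """Pick m = max(int(frac * num_users), 1) distinct client indices.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    m = max(int(frac * num_users), 1)  # always keep at least one client
    return sorted(rng.sample(range(num_users), m))


# With 100 clients and a fraction of 0.1, only 10 clients upload per round.
selected = sample_clients(num_users=100, frac=0.1, seed=0)
print(len(selected), selected)
```

Only the selected clients train and upload in a given round, which is where the saving in communication comes from.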
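1. At step t of each round, the server sends the current global model parameters θ_t to the clients.
2. Each client i in the sampled subset starts from θ_t and updates its local model with SGD.
3. Each client in the sampled subset uploads its updated local parameters θ_{i,t+1}.
4. At step t+1, the server computes the weighted average θ_{t+1} from the uploaded local parameters θ_{i,t+1}.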
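The four steps above can be sketched on a toy problem. This is a simplified illustration under stated assumptions, not the paper's implementation: the model is a single scalar, each client's loss is the mean squared distance to its own data, local training uses full-batch gradient steps instead of mini-batch SGD, and the data and the names `local_update` and `fedavg_round` are made up.

```python
import random


def local_update(theta, data, lr=0.1, local_steps=5):
    """Step 2: several local gradient steps on L_i(theta) = mean((theta - x)^2) / 2."""
    for _ in range(local_steps):
        grad = sum(theta - x for x in data) / len(data)
        theta -= lr * grad
    return theta


def fedavg_round(theta, client_data, m, rng):
    """One round: broadcast theta, train on m sampled clients, average by data size."""
    sampled = rng.sample(range(len(client_data)), m)       # choose the subset
    updates = [(local_update(theta, client_data[i]), len(client_data[i]))
               for i in sampled]                            # steps 1-3
    n = sum(n_i for _, n_i in updates)
    return sum(n_i / n * theta_i for theta_i, n_i in updates)  # step 4


rng = random.Random(0)
# Made-up data: three clients holding different amounts of data.
client_data = [[1.0, 1.2], [2.0, 2.1, 1.9, 2.0], [3.0]]
theta = 0.0
for t in range(50):
    theta = fedavg_round(theta, client_data, m=2, rng=rng)
print(round(theta, 3))
```

Because only two of the three clients are sampled per round, the result keeps fluctuating with the random subset; with m equal to the number of clients, this toy problem settles at the mean of all data points.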
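Optimization objective: min_θ L(θ) = Σ_i p_i L_i(θ)
Here p_i is the weight of client i, usually p_i = n_i/n, where n_i is the number of samples on client i and n is the total number of samples. FedAvg ultimately minimizes this weighted average of the local losses L_i(θ).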
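As a small worked example of the weighted objective, the sketch below evaluates Σ_i p_i L_i with p_i = n_i/n. The loss values and sample counts are made up, and `global_objective` is a hypothetical helper name.

```python
def global_objective(local_losses, sample_counts):
    """Return sum_i p_i * L_i with p_i = n_i / n."""
    n = sum(sample_counts)
    return sum(n_i / n * loss for loss, n_i in zip(local_losses, sample_counts))


# Made-up numbers: three clients with 100, 300 and 600 samples.
losses = [0.9, 0.5, 0.2]
counts = [100, 300, 600]

# 0.1 * 0.9 + 0.3 * 0.5 + 0.6 * 0.2, i.e. about 0.36
print(global_objective(losses, counts))
```

Clients with more data pull the objective toward their own loss; with equal sample counts the expression reduces to a plain mean.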
import copy

import numpy as np
from tqdm import tqdm

# args, device, logger, global_model, train_dataset, user_groups,
# LocalUpdate and average_weights are defined elsewhere in the repository.

# Set the model to train and send it to device.
global_model.to(device)
global_model.train()
print(global_model)

# Copy the initial global weights
global_weights = global_model.state_dict()

# Training
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0

for epoch in tqdm(range(args.epochs)):
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n')

    global_model.train()
    # Sample a fraction of the clients for this round (at least one)
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)

    # Each sampled client trains a copy of the global model on its own data
    for idx in idxs_users:
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                  idxs=user_groups[idx], logger=logger)
        w, loss = local_model.update_weights(
            model=copy.deepcopy(global_model), global_round=epoch)
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))

    # Aggregate the local weights into the new global weights
    global_weights = average_weights(local_weights)

    # Load the aggregated weights into the global model
    global_model.load_state_dict(global_weights)

    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)

    # Calculate avg training accuracy over all users at every epoch
    list_acc, list_loss = [], []
    global_model.eval()
    for c in range(args.num_users):
        # Use the loop variable c here; the stale idx from the training loop
        # would evaluate the same client's data num_users times.
        local_model = LocalUpdate(args=args, dataset=train_dataset,
                                  idxs=user_groups[c], logger=logger)
        acc, loss = local_model.inference(model=global_model)
        list_acc.append(acc)
        list_loss.append(loss)
    train_accuracy.append(sum(list_acc)/len(list_acc))
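The loop calls `average_weights(local_weights)` without passing any sample counts, and its definition is not shown in this excerpt. A minimal sketch, assuming it takes a plain key-wise mean over the clients' parameter dictionaries, could look like the following. To stay free of any deep learning dependency, the stand-in below uses flat lists of floats instead of tensors; `average_weights_sketch` and the parameter names are hypothetical.

```python
import copy


def average_weights_sketch(local_weights):
    """Key-wise unweighted mean of a list of parameter dicts.

    Stand-in for averaging state_dicts: values are flat lists of floats
    rather than tensors. This is an assumption about the behaviour, not
    the repository's actual code.
    """
    avg = copy.deepcopy(local_weights[0])  # do not modify the first client's dict
    for key in avg:
        for w in local_weights[1:]:
            avg[key] = [a + b for a, b in zip(avg[key], w[key])]
        avg[key] = [a / len(local_weights) for a in avg[key]]
    return avg


clients = [
    {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
    {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]},
]
print(average_weights_sketch(clients))
# {'fc.weight': [2.0, 3.0], 'fc.bias': [0.5]}
```

An unweighted mean like this matches the n_i/n weighting only when every client holds the same number of samples; with unequal data sizes the sample counts would have to be passed in as well.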
Source code:
Federated-Learning-PyTorch/federated_main.py at master · AshwinRJ/Federated-Learning-PyTorch · GitHub