Federated Learning with Non-IID Data

Preface

Notes on the paper "Federated Learning with Non-IID Data", which studies how federated learning performs when each client's local data is non-IID, and how to improve it.

Abstract

  • In this work, we focus on the statistical challenge of federated learning when local data is Non-IID.
  • The accuracy of federated learning reduces significantly, by up to ~55% when each client trains on only a single class of data.
  • This accuracy reduction can be explained by the weight divergence, which can be quantified by the earth mover's distance (EMD).
  • As a solution, we propose a strategy to improve training on non-IID data by creating a small subset of data which is globally shared between all the edge devices.

FedAvg on Non-IID data

The training set is evenly partitioned across 10 clients.

  • For IID setting, each client is randomly assigned a uniform distribution over 10 classes
  • For non-IID setting, the data is sorted by class and divided to create two extreme cases:
    • 1-class non-IID (each client receives a data partition from only a single class)
    • 2-class non-IID (the sorted data is divided into 20 partitions and each client is randomly assigned 2 partitions from 2 classes; each client gets data from two classes, but the number of samples per class may differ)
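The sort-and-shard partitioning described above can be sketched as follows. This is a minimal stdlib sketch with a hypothetical helper name `partition`; in this simplification a client's two shards may occasionally come from the same class, whereas the paper assigns shards from two distinct classes:

```python
import random

def partition(labels, num_clients=10, classes_per_client=1, seed=0):
    """Sketch of the extreme non-IID partitioning (hypothetical helper).

    classes_per_client=1 gives the 1-class non-IID case; =2 gives the
    2-class case (each client gets 2 shards of the class-sorted data).
    """
    rng = random.Random(seed)
    # Sort sample indices by class label, then slice into equal shards.
    idx = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * classes_per_client
    shard_size = len(idx) // num_shards
    shards = [idx[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    # Each client receives `classes_per_client` random shards.
    return [sum(shards[c * classes_per_client:(c + 1) * classes_per_client], [])
            for c in range(num_clients)]

# Toy example: 100 samples over 10 classes.
labels = [i % 10 for i in range(100)]
clients = partition(labels, classes_per_client=2)
print([sorted({labels[i] for i in c}) for c in clients])  # <= 2 classes per client
```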

Key variables:

  • B: batch size
  • E: the number of local epochs (how many passes each client makes over its local data)
  • decay rate

(Note: an epoch is one full pass over the dataset, while an iteration is a single mini-batch update.)
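How B and E enter the FedAvg loop can be sketched as follows. The helper names `local_update` and `fedavg_round` are my own, and a toy linear-regression model stands in for the paper's CNNs:

```python
import numpy as np

def local_update(w, X, y, E=1, B=10, lr=0.1):
    """One client's contribution: E local epochs of mini-batch SGD
    with batch size B, on a toy least-squares objective."""
    w = w.copy()
    n = len(X)
    for _ in range(E):                       # E local epochs
        order = np.random.permutation(n)
        for s in range(0, n, B):             # iterations within an epoch
            b = order[s:s + B]
            grad = 2 * X[b].T @ (X[b] @ w - y[b]) / len(b)
            w -= lr * grad
    return w

def fedavg_round(w_global, clients):
    """Server step: average client weights, weighted by local data size."""
    total = sum(len(X) for X, _ in clients)
    return sum(len(X) / total * local_update(w_global, X, y) for X, y in clients)

# Toy run: 10 clients, each holding 100 samples of y = 3x.
rng = np.random.default_rng(0)
clients = []
for _ in range(10):
    X = rng.normal(size=(100, 1))
    clients.append((X, X @ np.array([3.0])))
w = np.zeros(1)
for _ in range(20):
    w = fedavg_round(w, clients)
print(w)  # converges toward w ≈ 3
```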

Experimental Results:
[Figure 1]

  • With 10 clients in FedAvg, each client trains on 100 local samples per round, so each round covers 1,000 samples in total; to keep the comparison fair, B is set to 1000 for the centralized SGD baseline
  • The figures show that the greater the degree of non-IID, the lower the test accuracy
  • A larger number of local epochs (E = 5) does not improve the test accuracy

[Figure 2]

Weight Divergence due to Non-IID Data

weight divergence = ||w_FedAvg − w_SGD|| / ||w_SGD||

  • The weight divergence of all the layers increases as the data becomes more non-IID, from IID to 2-class non-IID to 1-class non-IID
  • The root cause of the weight divergence is the distance between the data distribution on each client and the population distribution
  • Such distance can be evaluated with the earth mover's distance (EMD) between the distributions
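For discrete class labels, this distance reduces to a sum of per-class probability gaps between a client's label distribution and the population's. A minimal stdlib sketch (helper names are my own):

```python
from collections import Counter

def label_distribution(labels, num_classes):
    """Empirical class distribution p(y = i) over a list of labels."""
    counts = Counter(labels)
    n = len(labels)
    return [counts.get(i, 0) / n for i in range(num_classes)]

def emd(client_labels, population_labels, num_classes=10):
    """Distance between a client's label distribution and the
    population distribution: sum_i |p_k(y = i) - p(y = i)|."""
    p_k = label_distribution(client_labels, num_classes)
    p = label_distribution(population_labels, num_classes)
    return sum(abs(a - b) for a, b in zip(p_k, p))

population = [i % 10 for i in range(1000)]    # uniform over 10 classes
print(emd(population[:100], population))      # IID client: distance is 0
print(emd([0] * 100, population))             # 1-class client: near the maximum of 1.8
```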

The effect of different degrees of non-IID data on weight divergence:
[Figure 3]
Conclusion: across the different layers of the CNN, the greater the degree of non-IID, the larger the weight divergence

weight divergence vs EMD:

  • the bound on the weight divergence is affected by the EMD

Test accuracy vs EMD:

  • For all three datasets, the test accuracy decreases as the EMD increases

[Figure 4]
[Figure 5]

Proposed solution

  • we propose a data-sharing strategy to improve FedAvg with non-IID data by creating a small subset of data which is globally shared between all the edge devices.
  • we can distribute a small subset of global data containing a uniform distribution over classes from the cloud to the clients (done at the initialization stage)
  • Instead of distributing a model with random weights, a warm-up model can be trained on the globally shared data and distributed to the clients.
  • Because the globally shared data can reduce EMD for the clients, the test accuracy is expected to improve

[Figure 6]

Data-sharing Strategy

  • A globally shared dataset G that consists of a uniform distribution over classes is centralized in the cloud.
  • At the initialization stage of FedAvg, the warm-up model trained on G and a random α portion of G are distributed to each client
  • The local model of each client is trained on the shared data from G together with private data from each client
Two trade-offs
  • the trade-off between the test accuracy and the size of G, quantified as β = ||G|| / ||D|| × 100% (D: the total data from the clients)
  • the trade-off between the test accuracy and α
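The strategy and its two knobs can be sketched in a few lines. `share_globally` is a hypothetical helper, and the integer lists stand in for real datasets:

```python
import random

def share_globally(G, client_datasets, alpha=0.1, seed=0):
    """Sketch of the data-sharing strategy (hypothetical helper).

    Each client receives a random alpha-fraction of the shared set G
    and trains on it together with its own private data.
    """
    rng = random.Random(seed)
    k = int(alpha * len(G))
    return [private + rng.sample(G, k) for private in client_datasets]

G = list(range(500))                           # stand-in for the shared set
clients = [list(range(1000)) for _ in range(10)]
D_size = sum(len(c) for c in clients)
beta = len(G) / D_size * 100                   # beta = ||G|| / ||D|| x 100%
augmented = share_globally(G, clients, alpha=0.2)
print(beta, len(augmented[0]))                 # 5.0, 1000 private + 100 shared samples
```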

[Figure 7]

Conclusions:

  • The larger β (i.e., the larger G, the more shared data), the higher the test accuracy; likewise, the larger α (the more shared data distributed to each client), the higher the test accuracy
  • the data-sharing strategy offers a solution for federated learning with non-IID data
  • the strategy only needs to be performed once when federated learning is initialized, so the communication cost is not a major concern
  • the globally shared data is a separate dataset from the clients’ data, so it is not privacy sensitive

Conclusion

  • the quality of model training degrades if each of the edge devices sees a unique distribution of data
  • as a solution, we propose a strategy to improve training on non-IID data by creating a small subset of data which is globally shared between all the edge devices
