论文阅读:Cluster-driven Graph Federated Learning over Multiple Domains

Cluster-driven Graph Federated Learning over Multiple Domains

(1)动机:为了解决统计异质性,到目前为止,统计异质性已经用不同的方法来处理,但没有一种方法能模拟领域之间的直接知识共享。
(2)创新点:聚类用于修饰统计异质性,而图卷积网络(GCNs)则实现了不同域之间的知识共享,FedCG是第一个通过GCN来建模领域-领域交互的,GCN连接特定领域的模型组件。
(3)提出的模型:FEDCG
a.通过符合fl的聚类来标识域,并为每个域实例化特定于域的模块(剩余分支)
b.通过训练时的GCN连接特定领域的模块,学习领域之间的相互作用并共享知识
c.通过教师-学生分类器-训练迭代来学习无监督的聚类,并通过领域软分配分数来处理新的未见过的测试领域。
论文阅读:Cluster-driven Graph Federated Learning over Multiple Domains_第1张图片
服务器将模型fθ以及教师gϕ和学生gφ域分类器发送给客户端,在客户端,域分类器对局部数据x进行聚类,输出每个图像所属的域dˆ。在训练时,由gϕ预测硬标签dˆ,并通过知识蒸馏的过程作为输入训练gφ。在测试时,dˆ由gφ给出,是所发现域的加权组合。在FedCG中,网络fθ由域不可知部分(灰色)和残差域特定部分(蓝色)组成。特定领域的参数由GCN产生,接收输入A, W, V和dˆ。客户端k对fθ和gϕ的数据进行训练后,将更新后的权值θk和ϕ返回给服务器。在服务器端,通过FedAvg算法聚合更新。
notes:
a.聚类算法可以在测试时将看不见的数据分配给聚类,这多亏了域分类器。
b.由域的相似性的公式邻接矩阵的公式得,每个客户端不仅可以得到参数θ的集合,还可以得到邻接矩阵。通过这个定义,我们迫使特定于领域的组件的梯度通过GCN流向所有其他组件。通常,对特定领域组件的更新将影响所有特定领域的参数,甚至是当前培训回合中没有出现的领域。
(4)实验设置:
a.在测试时,新域可以作为已发现域(分别为摩天大楼and海洋)的软组合来处理,例如海上的摩天大楼。
b.给定一个输入图像,教师提供域伪标签作为目标,以重新精确学生的预测。特别地,我们通过迭代最小化教师和学生域预测之间的交叉熵损失来学习客户端学生参数。
c.用了两个数据集:一个是二分类(CelebA),一个是62种类别(FEMNIST),比较FedAvg在TensorFlow和我们的方法在PyTorch上的分类的准确性。
d.消融实验(重点分析在CelebA数据集上测试所提出模型的性能)
1——将标准的单一服务器模型替换为N个独立的特定于领域的模型,由于训练时的模型数据量不足,因此效果极差;
2——ReLU非线性应用于残差GCN输出的重要性
3----

你可能感兴趣的:(深度学习,人工智能,机器学习)