联邦学习数据集划分Dirichlet划分法——pytorch实现

联邦学习数据集Dirichlet划分

做联邦学习数据集划分的时候,一般要考虑到数据的特异性,我们一般使用dirichlet分布来产生不同的客户端数据。
网上找的资料大部分都是numpy实现的dirichlet划分,但是因为强迫症 不想额外引入numpy,这里将介绍一下torch如何实现dirichlet划分的方法:
完整代码如下:
参数:

  1. train_labels: 数据集的标签列表
  2. dirichlet分布参数
  3. n_clients:有几个客户端需要分配

小小的解释一下:
首先我们使用Dirichlet函数返回了一个标签分布的矩阵tensor,这个tensor的维度是特征数X客户端数,每一行就是一个标签在不同客户端上的分布,总和为1。然后我们获得每一个标签的下标class_idcs。获得了这两个矩阵之后,我们只需要循环遍历每一个标签,就是:
for c, fracs in zip(class_idcs, label_distribution):
每次取出一个标签的下标位置,以及在每一个客户端的分配比例。通过分配比例拆分此标签,获得每一个客户端拥有的此标签的下标。

import torch
from torch.distributions.dirichlet import Dirichlet


def dirichlet_split_noniid(train_labels, alpha, n_clients):
    n_classes = train_labels.max() + 1
    label_distribution = Dirichlet(torch.full((n_clients,), alpha)).sample((n_classes,))
    # 1. Get the index of each label
    class_idcs = [torch.nonzero(train_labels == y).flatten()
                  for y in range(n_classes)]
    # 2. According to the distribution, the label is assigned to each client
    client_idcs = [[] for _ in range(n_clients)]

    for c, fracs in zip(class_idcs, label_distribution):
        total_size = len(c)
        splits = (fracs * total_size).int()
        splits[-1] = total_size - splits[:-1].sum()
        idcs = torch.split(c, splits.tolist())
        for i, idx in enumerate(idcs):
            client_idcs[i] += [idcs[i]]

    client_idcs = [torch.cat(idcs) for idcs in client_idcs]
    return client_idcs

唯一的坑点我觉得的就是,torch的split和numpy的split居然不一样。
numpy的split是按照累积和来分解,例如要将1~100分解成1 ~ 10, 10 ~ 50,numpy输入的数组是[10,50],而torch是按照实际大小,输入数组是[10, 40, 50]。不过个人感觉torch这个就直观很多,每一个拆分有多少数据。

你可能感兴趣的:(联邦学习,pytorch,人工智能,python,联邦学习)