Implementing the Federated Learning Algorithm FedAvg (PyTorch)

Broadly speaking, a federated learning (Federated Learning) system consists of a Server and several Clients. During training, no raw user data is ever sent to the Server, which protects users' data privacy. Moreover, the parameters exchanged during communication are only the updates needed to improve the current model, so once they have been applied the Server has no reason to keep them, which further improves security.

The overall idea of federated learning is "the data stays put, the model moves". The Server provides a globally shared model; each Client downloads the model, trains it on its own local data, and updates the model parameters. In every round of communication, the Server distributes the current model parameters to the Clients (in other words, the Clients download the server-side parameters); after local training, each Client sends its updated parameters back, and the Server fuses the N sets of client parameters into a single set that becomes the new global model. This process repeats round after round.

This post works through the FedAvg algorithm in practice, outlines its workflow, and implements it in PyTorch.

Dataset Partitioning

First we need to assign data to each client. In a real deployment every client has its own private data; to simulate this, we manually partition a public dataset among the clients.

The data across clients may be independent and identically distributed (IID) or non-IID.

Take MNIST, the dataset of handwritten digits 0-9, as an example. IID means every client holds a mix of all ten digit classes, whereas non-IID means each client only holds a subset of the classes, e.g. one client has only 0s and 1s while another has only 2s and 3s.

For more background, see "Improving the efficiency and effectiveness of federated learning" by 雪琪 on Zhihu: https://zhuanlan.zhihu.com/p/108163485

import numpy as np

def cifar_iid(dataset, num_users):
    """Sample IID client data: each client gets an equal-sized random subset of the dataset."""
    num_items = int(len(dataset) / num_users)
    # num_items = 500  # for quick testing

    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        # draw num_items indices without replacement, then remove them from the remaining pool
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    print('cifar_iid is ok!')
    return dict_users
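
For reference, here is a small usage sketch showing how the IID partition can be driven with torchvision's CIFAR-10; the transform, download path, and number of clients below are illustrative assumptions, not taken from the original post:

from torchvision import datasets, transforms

# assumed preprocessing; the original post does not show its transforms
trans_cifar = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True, transform=trans_cifar)

num_users = 100                                    # hypothetical number of clients
dict_users = cifar_iid(dataset_train, num_users)   # index sets, one per client
print(len(dict_users[0]))                          # 500 indices per client for 50,000 training images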

def cifar_noniid(dataset, num_users):
    """Sample non-IID client data: each client holds images from only two classes."""
    num_items = int(len(dataset) / num_users)   # total images per client
    num_labels = 2                              # each client holds only two classes
    num_pics = int(num_items / num_labels)      # images per class per client
    dict_users, idxs, per_labal_idxs = {}, [i for i in range(10)], {}

    # assumes the dataset is ordered by label: class i occupies indices [i*5000, (i+1)*5000)
    for i in range(10):
        per_labal_idxs[i] = [j for j in range(i * 5000, (i + 1) * 5000)]

    for i in range(num_users):
        # pick two distinct classes that still have images left
        random_labels = np.random.choice(idxs, 2, replace=False)
        random_label_1 = set(np.random.choice(per_labal_idxs[random_labels[0]], num_pics, replace=False))
        per_labal_idxs[random_labels[0]] = list(set(per_labal_idxs[random_labels[0]]) - random_label_1)
        if len(per_labal_idxs[random_labels[0]]) == 0:
            idxs.remove(random_labels[0])

        random_label_2 = set(np.random.choice(per_labal_idxs[random_labels[1]], num_pics, replace=False))
        per_labal_idxs[random_labels[1]] = list(set(per_labal_idxs[random_labels[1]]) - random_label_2)
        if len(per_labal_idxs[random_labels[1]]) == 0:
            idxs.remove(random_labels[1])

        dict_users[i] = set.union(random_label_1, random_label_2)

    print('cifar_noniid is ok!')
    return dict_users
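
Note that cifar_noniid assumes indices 0-4999 all belong to class 0, 5000-9999 to class 1, and so on. The raw torchvision CIFAR-10 training set is not ordered that way, so a re-indexing step is needed before calling it. A minimal sketch under that assumption (the Subset wrapper and variable names are mine, not from the original post):

import numpy as np
from torch.utils.data import Subset

# sort the training indices by label so that class c occupies a contiguous block of 5000 indices
labels = np.array(dataset_train.targets)          # torchvision's CIFAR10 exposes labels via .targets
sorted_idxs = np.argsort(labels, kind='stable')
dataset_sorted = Subset(dataset_train, sorted_idxs)

dict_users = cifar_noniid(dataset_sorted, num_users)  # now the i*5000 slicing matches class i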

Federated Training

The Server initializes the model and shares its parameters

# Build the network. CNNCifar is an 8-layer CNN that reaches over 80% accuracy on the CIFAR-10 test set.
net_glob = CNNCifar(args=args).to(args.device)
net_glob.train()
# copy weights: grab the global model parameters so they can be shared with the Clients
w_glob = net_glob.state_dict()
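
The post does not show CNNCifar itself. Below is a minimal sketch of what such a model could look like; the layer sizes are my own assumptions, and args.num_classes is a hypothetical field (10 for CIFAR-10):

import torch.nn as nn
import torch.nn.functional as F

class CNNCifar(nn.Module):
    """A small CNN for 32x32x3 CIFAR-10 images; the exact architecture in the post may differ."""
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # 32x32 -> 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))   # 14x14 -> 10x10 -> 5x5
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                     # raw logits, consumed by CrossEntropyLoss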

Server-Client Communication

With the shared model parameters in hand, the communication rounds (local training at the Clients followed by aggregation at the Server) can begin:

# args.epochs: number of communication rounds between Server and Clients
loss_train = []
for iter in range(args.epochs):
    w_locals, loss_locals = [], []
    m = max(int(args.frac * args.num_users), 1)
    # randomly sample a fraction of the Clients; using all of them increases communication
    # cost and does not necessarily improve the results
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    # each selected Client trains on its own data starting from the current global model
    # and returns its updated parameters
    for idx in idxs_users:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
    # update global weights: aggregate the client parameters into the new Server model
    if args.aggregate == 'avg':
        w_glob = FedAvg(w_locals)
    elif args.aggregate == 'median':
        # FedMedian (coordinate-wise median, not shown in this post) is an alternative aggregation rule
        w_glob = FedMedian(w_locals)

    # copy the aggregated weights back into net_glob
    net_glob.load_state_dict(w_glob)

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
    loss_train.append(loss_avg)
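
The args namespace used throughout is not defined in the excerpts above. A plausible minimal setup with argparse, where every default value is an assumption chosen for illustration:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=100, help='number of communication rounds')
parser.add_argument('--num_users', type=int, default=100, help='number of clients')
parser.add_argument('--frac', type=float, default=0.1, help='fraction of clients sampled per round')
parser.add_argument('--local_ep', type=int, default=5, help='local epochs per client per round')
parser.add_argument('--local_bs', type=int, default=50, help='local batch size')
parser.add_argument('--bs', type=int, default=128, help='test batch size')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
parser.add_argument('--aggregate', type=str, default='avg', help='aggregation rule: avg or median')
parser.add_argument('--gpu', type=int, default=0, help='GPU id, -1 for CPU')
parser.add_argument('--verbose', action='store_true', help='print per-batch training logs')
args = parser.parse_args()
args.device = torch.device('cuda:{}'.format(args.gpu)
                           if torch.cuda.is_available() and args.gpu != -1 else 'cpu')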

The line local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) above constructs the client-side trainer, whose implementation is shown next.

Client-Side Training

import torch
from torch import nn
from torch.utils.data import DataLoader

class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        # only load the samples whose indices were assigned to this client
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        # switch the model to training mode
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()                          # zero the parameter gradients
                log_probs = net(images)                  # forward pass through the model shared by the Server
                loss = self.loss_func(log_probs, labels) # loss on the current batch
                loss.backward()                          # backward pass to compute gradients
                optimizer.step()                         # update the local parameters
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        # return the parameters this Client learned from its own data, plus its average local loss
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
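
DatasetSplit, used when building the DataLoader above, is a thin wrapper that restricts a dataset to the indices assigned to one client. Its definition is not shown in the post; a minimal sketch:

from torch.utils.data import Dataset

class DatasetSplit(Dataset):
    """Expose only the samples whose indices were assigned to this client."""
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label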

The FedAvg Algorithm

The aggregation itself is just an element-wise average of the client parameters. (In the original FedAvg paper the average is weighted by each client's number of samples; since every client here holds the same amount of data, the weighted average reduces to a simple mean.)

def FedAvg(w):
    # w is a list of state_dicts, one per client; average them key by key
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.true_divide(w_avg[k], len(w))
    return w_avg
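
If the clients held different amounts of data, the sample-weighted version from the paper could be sketched as follows; the FedAvgWeighted name and the num_samples argument are hypothetical, not part of the post's code:

def FedAvgWeighted(w, num_samples):
    """Weighted average: client k contributes in proportion to its number of training samples."""
    total = sum(num_samples)
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        w_avg[k] = w[0][k] * (num_samples[0] / total)
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] * (num_samples[i] / total)
    return w_avg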

Testing

After training finishes, we evaluate the model's generalization on a test set. Note that although the Server never sees a single training sample, the ultimate goal of federated learning is still to obtain a robust model on the Server side, so testing is performed on the Server with the global model, as follows:

def test_img(net_g, datatest, args):
    # switch the model to evaluation mode before testing
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs, shuffle=False)
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss
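
Once the communication rounds are done, the global model can be evaluated on both the combined training data and the test split, for example (dataset_test is assumed to be the torchvision CIFAR-10 test set loaded the same way as dataset_train):

net_glob.eval()
acc_train, loss_train_final = test_img(net_glob, dataset_train, args)
acc_test, loss_test = test_img(net_glob, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))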
