基于cifar10数据集的FedAvg联邦学习任务

本文目录

    • 1. cifar10数据集介绍
    • 2. 联邦学习介绍
    • 3. 客户端CNN模型
    • 3. FedAvg实现
    • 运行结果

出自论文《Communication-Efficient Learning of Deep Networks from Decentralized Data》

1. cifar10数据集介绍

CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

2. 联邦学习介绍

FL的数据应具有以下特性(标准):

  1. 训练来自移动设备的真实数据比数据中心提供的代理数据具有明显的优势;
  2. 数据是隐私敏感的或较大规模的,不需要仅出于训练模型的目的将其记录在数据中心;
  3. 对于监督任务,可以从用户交互中自然地推断出数据的标签。

数据是敏感的: 用户的照片或键盘输入的文本;

数据的分布也与代理数据提供的不同, 更有用户特点和优势;

数据的标签也是可以直接获得的:比如用户的照片和输入的文字等本身就是带标签的;照片可以通过用户的交互操作进行打标签(删除、分享、查看)。
联邦学习过程:

  1. 在每轮更新开始时,随机选择部分客户端,大小为C-fraction(应该是比例,C≤1);
  2. 然后,服务器将当前的全局算法的状态发送给这些客户(例如,当前的模型参数);
  3. 然后,每个客户端都基于全局状态及其本地数据集执行本地计算,并将更新发送到服务器;
  4. 最后,服务器将这些更新应用于其全局状态,然后重复该过程。
    基于cifar10数据集的FedAvg联邦学习任务_第1张图片

3. 客户端CNN模型

import torch.nn.functional as F
from torch import nn
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader


class CNN(nn.Module):
    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=n_kernels, out_channels=2 * n_kernels, kernel_size=5)
        self.fc1 = nn.Linear(in_features=2 * n_kernels * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=out_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Client(object):
    def __int__(self, trainDataSet, dev):
        self.train_ds = trainDataSet
        self.dev = dev
        self.train_dl = None
        self.local_parameter = None

3. FedAvg实现

def evaluate(net, global_parameters, testDataLoader, dev):
    net.load_state_dict(global_parameters, strict=True)
    running_correct = 0
    running_samples = 0
    net.eval()
    # 载入测试集
    for data, label in testDataLoader:
        data, label = data.to(dev), label.to(dev)
        pred = net(data)
        running_correct += pred.argmax(1).eq(label).sum().item()
        running_samples += len(label)
    print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")


def local_upload(train_data_set, local_epoch, net, loss_fun, opt, global_parameters, dev):
    # 加载当前通信中最新全局参数
    net.load_state_dict(global_parameters, strict=True)
    # 设置迭代次数
    net.train()
    for epoch in range(local_epoch):
        for data, label in train_data_set:
            data, label = data.to(dev), label.to(dev)
            # 模型上传入数据
            predict = net(data)
            loss = loss_fun(predict, label)
            # 反向传播
            loss.backward()
            # 计算梯度,并更新梯度
            opt.step()
            # 将梯度归零,初始化梯度
            opt.zero_grad()
    # 返回当前Client基于自己的数据训练得到的新的模型参数
    return net.state_dict()


def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
          steps: int, node_iter: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          seed: int) -> None:
    ###############################
    # init nodes, hnet, local net #
    ###############################
    steps = 5
    node_iter = 5
    nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                      batch_size=bs)
    net = CNN(n_kernels=n_kernels)
    # hnet = hnet.to(device)
    net = net.to(device)

    ##################
    # init optimizer #
    ##################
    # embed_lr = embed_lr if embed_lr is not None else lr
    optimizer = torch.optim.SGD(
        net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
    )
    criteria = torch.nn.CrossEntropyLoss()

    ################
    # init metrics #
    ################
    # step_iter = trange(steps)
    step_iter = range(steps)
    # train process
    # record  the global parameters
    global_parameters = {}
    for key, parameter in net.state_dict().items():
        global_parameters[key] = parameter.clone()
    for step in step_iter:

        local_parameters_list = {}
        # 需要训练的node数目
        for i in range(node_iter):
            # 随机选择一个客户端
            node_id = random.choice(range(num_nodes))
            # 用全局模型参数训练当前客户端
            local_parameters = local_upload(nodes.train_loaders[node_id], 5, net, criteria, optimizer,
                                            global_parameters, dev='cpu')
            print("\nEpoch: {}, Node Count: {}, Node ID: {}".format(step + 1, i + 1, node_id), end="")
            evaluate(net, local_parameters, nodes.val_loaders[node_id], 'cpu')
            local_parameters_list[i] = local_parameters

        # 更新当前轮次模型的参数
        sum_parameters = None
        for node_id, parameters in local_parameters_list.items():
            if sum_parameters is None:
                sum_parameters = parameters
            else:
                for key in parameters.keys():
                    sum_parameters[key] += parameters[key]
        for var in global_parameters:
            global_parameters[var] = (sum_parameters[var] / node_iter)
    # test
    net.load_state_dict(global_parameters, strict=True)
    net.eval()
    for data_set in nodes.test_loaders:
        running_correct = 0
        running_samples = 0
        for data, label in data_set:
            pred = net(data)
            running_correct += pred.argmax(1).eq(label).sum().item()
            running_samples += len(label)
        print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")

运行结果

基于cifar10数据集的FedAvg联邦学习任务_第2张图片

你可能感兴趣的:(联邦学习,pytorch,深度学习,python)