CIFAR-10 is a color-image dataset of common, everyday objects, compiled by Alex Krizhevsky and Ilya Sutskever, students of Hinton. It contains RGB images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. Each image is 32 × 32 pixels, each class has 6,000 images, and the dataset is split into 50,000 training images and 10,000 test images.
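For reference, CIFAR-10 can be loaded directly through torchvision; the snippet below is a minimal sketch, where the normalization constants, batch size, and the "./data" download path are illustrative assumptions rather than values from this text:

import torch
import torchvision
import torchvision.transforms as transforms

# Convert PIL images to tensors and normalize each RGB channel to roughly [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# "./data" is an illustrative download location
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)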
Data suited to FL should have the following characteristics (criteria):
The data is privacy-sensitive: for example, users' photos or the text they type on the keyboard;
The data distribution differs from what proxy data provides, so it carries user-specific characteristics and is more useful (a non-IID split is sketched right after this list);
Labels can be obtained directly: users' photos and typed text are inherently self-labeling, and photos can also be labeled through user interactions (deleting, sharing, viewing).
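To make the non-IID point concrete, here is a minimal sketch of how such a skewed label distribution can be simulated; non_iid_partition is a hypothetical helper written for illustration (in the code below, the BaseNodes class plays this role via its classes_per_node argument):

import random

def non_iid_partition(labels, num_nodes, classes_per_node):
    """Assign each node sample indices drawn from only a few classes."""
    classes = sorted(set(labels))
    by_class = {c: [i for i, y in enumerate(labels) if y == c] for c in classes}
    partition = {}
    for node in range(num_nodes):
        # Each node sees only `classes_per_node` of the available classes
        chosen = random.sample(classes, classes_per_node)
        partition[node] = [i for c in chosen for i in by_class[c]]
    return partition

# Example: 10 CIFAR-10 classes, 5 nodes, 2 classes per node
# partition = non_iid_partition(train_labels, num_nodes=5, classes_per_node=2)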
The federated learning process:
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
class CNN(nn.Module):
    """Simple LeNet-style CNN for 32 x 32 RGB inputs (e.g., CIFAR-10)."""

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=n_kernels, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=n_kernels, out_channels=2 * n_kernels, kernel_size=5)
        # After two conv+pool stages a 32 x 32 input shrinks to 5 x 5 feature maps
        self.fc1 = nn.Linear(in_features=2 * n_kernels * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=out_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten to (batch_size, 2 * n_kernels * 5 * 5)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
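# --- Illustrative sanity check (an added sketch, not part of the original listing) ---
# A 32x32 input shrinks to 28x28 after conv1 (5x5), 14x14 after pooling,
# 10x10 after conv2 (5x5), and 5x5 after the second pooling, so fc1 sees
# 2 * n_kernels * 5 * 5 = 800 features for the default n_kernels=16.
_sanity_model = CNN(in_channels=3, n_kernels=16, out_dim=10)
_sanity_out = _sanity_model(torch.randn(4, 3, 32, 32))
assert _sanity_out.shape == (4, 10)  # (batch_size, num_classes)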
class Client(object):
    def __init__(self, trainDataSet, dev):
        self.train_ds = trainDataSet
        self.dev = dev
        self.train_dl = None
        self.local_parameter = None
def evaluate(net, global_parameters, testDataLoader, dev):
    # Load the parameters to be evaluated into the local model
    net.load_state_dict(global_parameters, strict=True)
    running_correct = 0
    running_samples = 0
    net.eval()
    # Iterate over the test set; no gradients are needed for evaluation
    with torch.no_grad():
        for data, label in testDataLoader:
            data, label = data.to(dev), label.to(dev)
            pred = net(data)
            running_correct += pred.argmax(1).eq(label).sum().item()
            running_samples += len(label)
    print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")
def local_upload(train_data_set, local_epoch, net, loss_fun, opt, global_parameters, dev):
    # Load the latest global parameters for this communication round
    net.load_state_dict(global_parameters, strict=True)
    net.train()
    # Run the configured number of local epochs
    for epoch in range(local_epoch):
        for data, label in train_data_set:
            data, label = data.to(dev), label.to(dev)
            # Forward pass on the local batch
            predict = net(data)
            loss = loss_fun(predict, label)
            # Reset gradients, backpropagate, then update the weights
            opt.zero_grad()
            loss.backward()
            opt.step()
    # Return the new model parameters trained on this client's own data
    return net.state_dict()
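# --- Illustrative single-client call (an added sketch, not from the original text;
# the synthetic data and hyperparameter values below are assumptions) ---
_device = 'cpu'
_model = CNN().to(_device)
_opt = torch.optim.SGD(_model.parameters(), lr=5e-3, momentum=0.9)
_loss_fn = torch.nn.CrossEntropyLoss()
# Tiny synthetic stand-in for one client's private data: 8 random 32x32 RGB images
_ds = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
_loader = DataLoader(_ds, batch_size=4, shuffle=True)
_global = {k: v.clone() for k, v in _model.state_dict().items()}
_local_params = local_upload(_loader, 5, _model, _loss_fn, _opt, _global, dev=_device)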
def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
          steps: int, node_iter: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          seed: int) -> None:
    ###############################
    # init nodes, hnet, local net #
    ###############################
    # Fix the experiment size: 5 communication rounds, 5 clients sampled per round
    # (these assignments override the values passed in through the signature)
    steps = 5
    node_iter = 5
    nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                      batch_size=bs)
    net = CNN(n_kernels=n_kernels)
    # hnet = hnet.to(device)
    net = net.to(device)

    ##################
    # init optimizer #
    ##################
    # embed_lr = embed_lr if embed_lr is not None else lr
    optimizer = torch.optim.SGD(
        net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
    )
    criteria = torch.nn.CrossEntropyLoss()

    ################
    # init metrics #
    ################
    # step_iter = trange(steps)
    step_iter = range(steps)

    # Training process: keep a copy of the global model parameters
    global_parameters = {}
    for key, parameter in net.state_dict().items():
        global_parameters[key] = parameter.clone()

    for step in step_iter:
        local_parameters_list = {}
        # Train `node_iter` nodes in this communication round
        for i in range(node_iter):
            # Pick a client at random
            node_id = random.choice(range(num_nodes))
            # Train the selected client starting from the current global parameters
            local_parameters = local_upload(nodes.train_loaders[node_id], 5, net, criteria,
                                            optimizer, global_parameters, dev=device)
            print("\nEpoch: {}, Node Count: {}, Node ID: {}".format(step + 1, i + 1, node_id), end="")
            evaluate(net, local_parameters, nodes.val_loaders[node_id], device)
            local_parameters_list[i] = local_parameters

        # FedAvg: average the collected local updates into the new global parameters
        sum_parameters = None
        for node_id, parameters in local_parameters_list.items():
            if sum_parameters is None:
                # Clone so the running sum does not alias the first client's tensors
                sum_parameters = {key: value.clone() for key, value in parameters.items()}
            else:
                for key in parameters.keys():
                    sum_parameters[key] += parameters[key]
        for var in global_parameters:
            global_parameters[var] = (sum_parameters[var] / node_iter)

        # Test the aggregated global model on every node's test set
        net.load_state_dict(global_parameters, strict=True)
        net.eval()
        with torch.no_grad():
            for data_set in nodes.test_loaders:
                running_correct = 0
                running_samples = 0
                for data, label in data_set:
                    data, label = data.to(device), label.to(device)
                    pred = net(data)
                    running_correct += pred.argmax(1).eq(label).sum().item()
                    running_samples += len(label)
                print("\t" + 'accuracy: %.2f' % (running_correct / running_samples), end="")
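A possible invocation of train is sketched below; every hyperparameter value and path is an illustrative assumption, not taken from this text, and BaseNodes (the non-IID data partitioner used above) must be supplied by the surrounding codebase:

# Illustrative call; all values here are assumptions
train(data_name="cifar10", data_path="./data", classes_per_node=2, num_nodes=50,
      steps=5, node_iter=5, optim="sgd", lr=1e-2, inner_lr=5e-3,
      embed_lr=None, wd=1e-3, inner_wd=5e-5, embed_dim=-1, hyper_hid=100,
      n_hidden=3, n_kernels=16, bs=64, device='cpu', eval_every=5,
      save_path=Path("./output"), seed=42)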