This federated learning model is an implementation of the paper Communication-Efficient Learning of Deep Networks from Decentralized Data. You should be familiar with the paper before reading on; this post focuses mainly on the code.
Source code: FedAvg
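As a quick refresher before diving into the code: in each communication round, FedAvg samples a fraction C of the K clients, lets every selected client k run E local epochs of minibatch SGD (batch size B) on its own n_k samples, and then aggregates the returned weights on the server with a data-size-weighted average,

w_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k \in S_t} n_k.

The main training loop below follows this structure.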
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
import torch
from tensorboardX import SummaryWriter
from options import args_parser
from update import LocalUpdate, test_inference
from model import CNNMnist
from utils import get_dataset, average_weights
if __name__ == '__main__':
    start_time = time.time()

    # Define the log path
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    # Parse the command-line arguments
    args = args_parser()

    # Choose the training device: CUDA or CPU
    if args.gpu and torch.cuda.is_available():
        device = 'cuda'
        print(f'device is {device}')
    else:
        device = 'cpu'
        print(f'device is {device}')

    # Load the dataset and the per-user index groups
    train_dataset, test_dataset, user_group = get_dataset(args)

    # Build the CNN model
    if args.model == 'cnn' and args.dataset == 'mnist':
        global_model = CNNMnist(args=args)
    else:
        exit('No suitable model found, you need to define one')

    # Move the model to the training device
    global_model.to(device)
    # model.train() enables Batch Normalization and Dropout
    global_model.train()
    print(global_model)

    # Get the initial global weights
    global_weight = global_model.state_dict()

    # Start training
    train_loss, train_acc = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        local_weight, local_losses = [], []
        print(f'\n global training round: {epoch + 1} | \n')
        global_model.train()

        # Sample a fraction C of the K users for this round
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # Train locally on each of the randomly selected clients
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_group[idx], logger=logger)
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
            local_weight.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # Update the global weights (FedAvg aggregation)
        global_weight = average_weights(local_weight)
        global_model.load_state_dict(global_weight)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Compute the average training accuracy of the selected users in this round
        list_acc, list_loss = [], []
        global_model.eval()
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_group[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_acc.append(sum(list_acc) / len(list_acc))

        # Print the global training stats every 'print_every' rounds
        if (epoch + 1) % print_every == 0:
            print(f'\n avg training stats after {epoch + 1} global rounds: ')
            print(f'training loss: {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_acc[-1]))

    # After training, evaluate on the test set
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_acc[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    # Save the train_loss and train_acc objects:
    file_name = './save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
    )
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_acc], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))

    matplotlib.use('Agg')

    # Plot Loss curve
    plt.figure()
    plt.title('Training Loss vs Communication rounds')
    plt.plot(range(len(train_loss)), train_loss, color='r')
    plt.ylabel('Training loss')
    plt.xlabel('Communication Rounds')
    plt.savefig('./save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
    ))

    # Plot Average Accuracy vs Communication rounds
    plt.figure()
    plt.title('Average Accuracy vs Communication rounds')
    plt.plot(range(len(train_acc)), train_acc, color='k')
    plt.ylabel('Average Accuracy')
    plt.xlabel('Communication Rounds')
    plt.savefig('./save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
    ))
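Since the script only pickles the train_loss and train_acc lists, they can be reloaded later for further plotting. A minimal sketch, assuming the default options above (so the format string resolves to fed_mnist_cnn_10_C[0.1]_iid[1]_E[10]_B[10].pkl):

import pickle

# File name produced by the format string above with the default options (assumption)
file_name = './save/fed_mnist_cnn_10_C[0.1]_iid[1]_E[10]_B[10].pkl'

with open(file_name, 'rb') as f:
    train_loss, train_acc = pickle.load(f)

print(f'{len(train_loss)} rounds, final train accuracy: {train_acc[-1]:.4f}')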
import argparse


def args_parser():
    parser = argparse.ArgumentParser()

    # federated arguments (notation follows the paper)
    parser.add_argument('--epochs', type=int, default=10, help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=10, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum (default: 0.5)')

    # model arguments
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel sizes to use for convolution')
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imgs")
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for mini-imagenet, 64 for omniglot.")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="whether to use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--gpu', default=True, help="set to a GPU id (default True) to use CUDA; falls back to CPU if CUDA is unavailable")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
    parser.add_argument('--iid', type=int, default=1, help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    args = parser.parse_args()
    return args
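To connect these defaults with the paper's notation: K = num_users = 100, C = frac = 0.1, E = local_ep = 10 and B = local_bs = 10. Each round the main loop therefore samples m = max(int(frac * num_users), 1) = max(int(0.1 * 100), 1) = 10 clients, and each of them runs 10 local epochs with minibatches of 10 samples before the server averages their weights.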
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around the PyTorch Dataset class."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
        self.device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
        # Default criterion set to the NLL loss function
        self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset and user indexes.
        """
        # split indexes for train, validation, and test (80, 10, 10)
        idxs_train = idxs[:int(0.8 * len(idxs))]
        idxs_val = idxs[int(0.8 * len(idxs)):int(0.9 * len(idxs))]
        idxs_test = idxs[int(0.9 * len(idxs)):]
        trainloader = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val), batch_size=int(len(idxs_val) / 10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=int(len(idxs_test) / 10), shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=1e-4)

        # Run local training for local_ep epochs
        for local_epoch in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)] \t Loss: {:.6f}'.format(
                        global_round, local_epoch, batch_idx * len(images),
                        len(self.trainloader.dataset), 100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    # Accuracy computation
    def inference(self, model):
        """ Returns the inference accuracy and loss."""
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)

            # Inference
            outputs = model(images)
            batch_loss = self.criterion(outputs, labels)
            loss += batch_loss.item()

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

        accuracy = correct / total
        return accuracy, loss


def test_inference(args, model, test_dataset):
    """ Returns the test accuracy and loss. """
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    for batch_idx, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)

        # Inference
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()

        # Prediction
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        correct += torch.sum(torch.eq(pred_labels, labels)).item()
        total += len(labels)

    accuracy = correct / total
    return accuracy, loss
from torch import nn
import torch.nn.functional as F


class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
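As a quick sanity check of the 320 in fc1: a 1×28×28 MNIST image becomes 10×12×12 after conv1 plus max-pooling, then 20×4×4 = 320 features after conv2 plus max-pooling. A minimal sketch to verify the shapes; SimpleNamespace here is only a stand-in for the parsed args (an assumption for this example):

import torch
from types import SimpleNamespace

# Hypothetical stand-in for args_parser(); only the two fields CNNMnist reads
args = SimpleNamespace(num_channels=1, num_classes=10)

model = CNNMnist(args=args)
dummy = torch.randn(4, 1, 28, 28)   # a fake batch of 4 MNIST images
out = model(dummy)
print(out.shape)                    # torch.Size([4, 10]) -- log-probabilities per class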
import copy
import numpy as np
import torch
from torchvision import datasets, transforms


def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image indexes
    """
    num_items = int(len(dataset) / num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    if args.dataset == 'mnist':
        data_dir = '../data/pytorch/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
    return train_dataset, test_dataset, user_groups


def average_weights(w):
    """ Returns the average of the weights."""
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
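Note that average_weights is a plain, unweighted mean over the selected clients. That matches the paper's weighted update w_{t+1} = Σ_k (n_k/n) w_{t+1}^k here because mnist_iid hands every user an equal-sized shard; with unequal splits each state_dict should be weighted by its client's sample count. A sketch of such a variant, assuming float parameters (weighted_average_weights and sample_counts are illustrative names, not part of the repo):

import copy

def weighted_average_weights(local_weights, sample_counts):
    """Average client state_dicts weighted by each client's sample count (illustrative sketch)."""
    total = float(sum(sample_counts))
    w_avg = copy.deepcopy(local_weights[0])
    for key in w_avg.keys():
        w_avg[key] = local_weights[0][key] * (sample_counts[0] / total)
        for i in range(1, len(local_weights)):
            w_avg[key] += local_weights[i][key] * (sample_counts[i] / total)
    return w_avg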
Running the script with no arguments uses the defaults defined in options.py:
python FedAvg.py
If you want to customize the parameters, for example to train on the CPU or increase the number of training rounds (--epochs), you can run:
python FedAvg.py --epochs=15 --gpu=False
One caveat about forcing CPU training: because --gpu is declared without a type, argparse passes --gpu=False through as the non-empty string 'False', which is still truthy in the `if args.gpu` check. To actually run on the CPU, pass an empty value (python FedAvg.py --gpu=) or change the --gpu default to None in options.py.
If you want to dig deeper into the source code, click here: fedavg