Reproducing the FedAvg Model from the Paper "Communication-Efficient Learning of Deep Networks from Decentralized Data"

A PyTorch reproduction of the FedAvg model

This federated learning model is implemented from the paper Communication-Efficient Learning of Deep Networks from Decentralized Data. Some familiarity with the paper is assumed, since this post focuses mainly on the code.
Source code: FedAvg

The FedAvg Algorithm

[Figure 1: FedAvg algorithm overview from the paper]
The overall idea of FedAvg is as follows (a minimal one-round sketch appears after the list):

  1. The server initializes the global model weights, randomly selects a fraction of the clients to participate in the round, and broadcasts the current weights to them.
  2. Each selected client loads the global weights and trains on its local data for E local epochs, where every epoch iterates over the data in mini-batches of size B; when local training finishes, the client sends its updated weights back to the server.
  3. The server averages the weights received from the clients to update the global model, then broadcasts the new weights for the next global round.
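
A minimal sketch of one communication round (hypothetical names; client_update stands for any function that trains a model copy locally and returns its state dict — the full implementation follows below):

import copy
import random
import torch

def fedavg_round(global_model, clients, frac, client_update):
  """One FedAvg communication round (illustrative sketch only)."""
  m = max(int(frac * len(clients)), 1)
  selected = random.sample(clients, m)  # the server picks a fraction C of the clients
  local_weights = [client_update(copy.deepcopy(global_model), c) for c in selected]
  # Average the returned state dicts key by key
  avg = copy.deepcopy(local_weights[0])
  for key in avg:
    for w in local_weights[1:]:
      avg[key] += w[key]
    avg[key] = torch.div(avg[key], len(local_weights))
  global_model.load_state_dict(avg)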

Code Implementation

FedAvg.py — the main training script

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from model import CNNMnist
from utils import get_dataset, average_weights


if __name__ == '__main__':
  start_time = time.time()
  
  # Set up the TensorBoard log directory
  path_project = os.path.abspath('..')
  logger = SummaryWriter('../logs')
  
  # Parse command-line arguments
  args = args_parser()

  # Select the training device: CUDA if requested and available, otherwise CPU
  device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
  print(f'device is {device}')
  
  # Load the dataset and the per-user index groups
  train_dataset, test_dataset, user_group = get_dataset(args)
  
  # Build the CNN model
  if args.model == 'cnn' and args.dataset == 'mnist':
    global_model = CNNMnist(args=args)
  else:
    exit('Error: no matching model for this dataset; define one first')
  
  # Move the model to the selected device
  global_model.to(device)
  
  # model.train() enables Batch Normalization and Dropout
  global_model.train()
  print(global_model)
  
  # Take a copy of the initial global weights
  global_weight = global_model.state_dict()
  
  # Training
  train_loss, train_acc = [], []
  val_acc_list, net_list = [], []
  cv_loss, cv_acc = [], []
  print_every = 2
  val_loss_pre, counter = 0, 0
  
  for epoch in tqdm(range(args.epochs)):
    local_weight, local_losses = [], []
    print(f'\n global training round: {epoch + 1} | \n')
    
    global_model.train()
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    
    # Train locally on each randomly selected client
    for idx in idxs_users:
      local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_group[idx], logger=logger)
      w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
      local_weight.append(copy.deepcopy(w))
      local_losses.append(copy.deepcopy(loss))
    
    # Update the global weights by averaging the local weights
    global_weight = average_weights(local_weight)
    global_model.load_state_dict(global_weight)
    
    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)
    
    # Compute the average training accuracy over the selected users in this round
    list_acc, list_loss = [], []
    global_model.eval()
    for idx in idxs_users:
      local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_group[idx], logger=logger)
      acc, loss = local_model.inference(model=global_model)
      list_acc.append(acc)
      list_loss.append(loss)
    
    train_acc.append(sum(list_acc) / len(list_acc))
    
    # Print global training stats every 'print_every' rounds
    if (epoch + 1) % print_every == 0:
      print(f'\n avg training stats after {epoch + 1} global rounds: ')
      print(f'training loss: {np.mean(np.array(train_loss))}')
      print('Train Accuracy: {:.2f}% \n'.format(100 * train_acc[-1]))
  
  # After training, evaluate on the test set
  test_acc, test_loss = test_inference(args, global_model, test_dataset)
  
  print(f' \n Results after {args.epochs} global rounds of training:')
  print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_acc[-1]))
  print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))
  
  # Save the train_loss and train_acc objects:
  os.makedirs('./save', exist_ok=True)  # make sure the output directory exists
  file_name = './save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(
    args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
  )
  
  with open(file_name, 'wb') as f:
    pickle.dump([train_loss, train_acc], f)
  
  print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))
  
  # Plot the loss curve
  plt.figure()
  plt.title('Training Loss vs Communication rounds')
  plt.plot(range(len(train_loss)), train_loss, color='r')
  plt.ylabel('Training loss')
  plt.xlabel('Communication Rounds')
  plt.savefig('./save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(
    args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
  ))
  
  # Plot Average Accuracy vs Communication rounds
  plt.figure()
  plt.title('Average Accuracy vs Communication rounds')
  plt.plot(range(len(train_acc)), train_acc, color='k')
  plt.ylabel('Average Accuracy')
  plt.xlabel('Communication Rounds')
  plt.savefig('./save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(
    args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
  ))

options.py — the command-line arguments

import argparse


def args_parser():
  parser = argparse.ArgumentParser()
  
  # federated arguments (Notation for the arguments followed from paper)
  parser.add_argument('--epochs', type=int, default=10, help="number of rounds of training")
  parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
  parser.add_argument('--frac', type=float, default=0.1, help='the fraction of clients: C')
  parser.add_argument('--local_ep', type=int, default=10, help="the number of local epochs: E")
  parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
  parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
  parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum (default: 0.5)')
  
  # model arguments
  parser.add_argument('--model', type=str, default='cnn', help='model name')
  parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
  parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                      help='comma-separated kernel size to use for convolution')
  parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imgs")
  parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
  parser.add_argument('--num_filters', type=int, default=32,
                      help="number of filters for conv nets -- 32 for mini-ImageNet, 64 for Omniglot.")
  parser.add_argument('--max_pool', type=str, default='True',
                      help="Whether use max pooling rather than strided convolutions")
  
  # other arguments
  parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
  parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
  parser.add_argument('--gpu', type=int, default=1, help="set to 1 to use CUDA when available, 0 to force CPU")
  parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
  parser.add_argument('--iid', type=int, default=1, help='Default set to IID. Set to 0 for non-IID.')
  parser.add_argument('--unequal', type=int, default=0,
                      help='whether to use unequal data splits for non-i.i.d setting (use 0 for equal splits)')
  parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
  parser.add_argument('--verbose', type=int, default=1, help='verbose')
  parser.add_argument('--seed', type=int, default=1, help='random seed')
  args = parser.parse_args()
  
  return args
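
To sanity-check the defaults, a hypothetical snippet (not part of the original repo) can be appended to options.py and run directly:

if __name__ == '__main__':
  # Print the parsed namespace; run with `python options.py`
  print(vars(args_parser()))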

update.py — local weight updates and inference

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DatasetSplit(Dataset):
  """An abstract Dataset class wrapped around Pytorch Dataset class."""
  
  def __init__(self, dataset, idxs):
    self.dataset = dataset
    self.idxs = [int(i) for i in idxs]
  
  def __len__(self):
    return len(self.idxs)
  
  def __getitem__(self, item):
    image, label = self.dataset[self.idxs[item]]
    # torch.as_tensor avoids the copy warning torch.tensor() raises on tensor inputs
    return torch.as_tensor(image), torch.as_tensor(label)

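A quick, hypothetical usage check for DatasetSplit with toy data (not part of the repo):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.randn(10, 1, 28, 28), torch.arange(10))
loader = DataLoader(DatasetSplit(toy, idxs=[0, 2, 4]), batch_size=2, shuffle=False)
images, labels = next(iter(loader))
print(images.shape, labels)  # torch.Size([2, 1, 28, 28]) tensor([0, 2])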

class LocalUpdate(object):
  def __init__(self, args, dataset, idxs, logger):
    self.args = args
    self.logger = logger
    self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
    self.device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
    # Default criterion set to NLL loss function
    self.criterion = nn.NLLLoss().to(self.device)
  
  def train_val_test(self, dataset, idxs):
    """
    Returns train, validation and test dataloaders for a given dataset and user indexes.
    """
    # split indexes for train, validation, and test (80, 10, 10)
    idxs_train = idxs[:int(0.8 * len(idxs))]
    idxs_val = idxs[int(0.8 * len(idxs)):int(0.9 * len(idxs))]
    idxs_test = idxs[int(0.9 * len(idxs)):]
    
    trainloader = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=True)
    # max(1, ...) guards against a zero batch size when a user's split has fewer than 10 samples
    validloader = DataLoader(DatasetSplit(dataset, idxs_val), batch_size=max(1, int(len(idxs_val) / 10)), shuffle=False)
    testloader = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=max(1, int(len(idxs_test) / 10)), shuffle=False)
    return trainloader, validloader, testloader
  
  def update_weights(self, model, global_round):
    # Set mode to train model
    model.train()
    epoch_loss = []
    
    # Set optimizer for the local updates
    if self.args.optimizer == 'sgd':
      optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=self.args.momentum)
    elif self.args.optimizer == 'adam':
      optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
    
    # Local training loop
    for local_epoch in range(self.args.local_ep):
      batch_loss = []
      for batch_idx, (images, labels) in enumerate(self.trainloader):
        images, labels = images.to(self.device), labels.to(self.device)
        
        model.zero_grad()
        log_probs = model(images)
        loss = self.criterion(log_probs, labels)
        loss.backward()
        optimizer.step()
        
        if self.args.verbose and (batch_idx % 10 == 0):
          print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)] \t Loss: {:.6f}'.format(
            global_round, local_epoch, batch_idx * len(images),
            len(self.trainloader.dataset), 100. * batch_idx / len(self.trainloader), loss.item()))
        self.logger.add_scalar('loss', loss.item())
        batch_loss.append(loss.item())
      epoch_loss.append(sum(batch_loss) / len(batch_loss))
    
    return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
  
  # Accuracy computation
  def inference(self, model):
    """ Returns the inference accuracy and loss."""
    
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    
    for batch_idx, (images, labels) in enumerate(self.testloader):
      images, labels = images.to(self.device), labels.to(self.device)
      
      # Inference
      outputs = model(images)
      batch_loss = self.criterion(outputs, labels)
      loss += batch_loss.item()
      
      # Prediction
      _, pred_labels = torch.max(outputs, 1)
      pred_labels = pred_labels.view(-1)
      correct += torch.sum(torch.eq(pred_labels, labels)).item()
      total += len(labels)
    
    accuracy = correct / total
    return accuracy, loss


def test_inference(args, model, test_dataset):
  """ Returns the test accuracy and loss. """
  
  model.eval()
  loss, total, correct = 0.0, 0.0, 0.0
  
  device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
  
  criterion = nn.NLLLoss().to(device)
  testloader = DataLoader(test_dataset, batch_size=128, shuffle=False)
  
  for batch_idx, (images, labels) in enumerate(testloader):
    images, labels = images.to(device), labels.to(device)
    
    # Inference
    outputs = model(images)
    batch_loss = criterion(outputs, labels)
    loss += batch_loss.item()
    
    # Prediction
    _, pred_labels = torch.max(outputs, 1)
    pred_labels = pred_labels.view(-1)
    correct += torch.sum(torch.eq(pred_labels, labels)).item()
    total += len(labels)
  
  accuracy = correct / total
  return accuracy, loss

model.py — the model definition

from torch import nn
import torch.nn.functional as F

class CNNMnist(nn.Module):
  def __init__(self, args):
    super(CNNMnist, self).__init__()
    self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, args.num_classes)
  
  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
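
As a quick shape check (a hypothetical snippet, not in the repo): a 1×28×28 MNIST input becomes 10×12×12 after conv1 plus pooling, then 20×4×4 = 320 features after conv2 plus pooling, which matches fc1's input size:

import argparse
import torch

args = argparse.Namespace(num_channels=1, num_classes=10)
model = CNNMnist(args)
out = model(torch.randn(2, 1, 28, 28))  # dummy batch of two MNIST-sized images
print(out.shape)  # torch.Size([2, 10]): per-class log-probabilities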

utils.py — data helpers: loading the dataset, partitioning it among users, and averaging weights

import copy
import numpy as np
import torch
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users):
  """
  Sample I.I.D. client data from MNIST dataset
  :param dataset:
  :param num_users:
  :return: dict of image index
  """
  num_items = int(len(dataset) / num_users)
  dict_users, all_idxs = {}, [i for i in range(len(dataset))]
  for i in range(num_users):
    dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
    all_idxs = list(set(all_idxs) - dict_users[i])
  return dict_users

def get_dataset(args):
  """ Returns train and test datasets and a user group which is a dict where
  the keys are the user index and the values are the corresponding data for
  each of those users.
  """
  if args.dataset == 'mnist':
    data_dir = '../data/pytorch/'
    
    apply_transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    
    train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=apply_transform)
    test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=apply_transform)
    
    # sample training data amongst users
    if args.iid:
      # Sample IID user data from MNIST
      user_groups = mnist_iid(train_dataset, args.num_users)
    else:
      # Only the IID split is implemented here; see the non-IID sketch below
      exit('Error: non-IID split not implemented')
      
  return train_dataset, test_dataset, user_groups
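
For --iid=0, a shard-style non-IID split in the spirit of the paper (sort the training images by label, cut them into contiguous shards, and deal each user a few shards) could look like the sketch below; mnist_noniid and its shard counts are assumptions, not part of the original repo:

def mnist_noniid(dataset, num_users, shards_per_user=2):
  """Sketch of a paper-style non-IID split: sort by label, deal out shards."""
  num_shards = num_users * shards_per_user
  shard_size = len(dataset) // num_shards
  idxs = np.argsort(dataset.targets.numpy())  # indices sorted by digit label
  shard_ids = np.random.permutation(num_shards)
  dict_users = {}
  for i in range(num_users):
    own = shard_ids[i * shards_per_user:(i + 1) * shards_per_user]
    dict_users[i] = np.concatenate([idxs[s * shard_size:(s + 1) * shard_size] for s in own])
  return dict_users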


def average_weights(w):
  """ Returns the average of the weights."""
  w_avg = copy.deepcopy(w[0])
  for key in w_avg.keys():
    for i in range(1, len(w)):
      w_avg[key] += w[i][key]
    w_avg[key] = torch.div(w_avg[key], len(w))
  return w_avg
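
Note that average_weights takes a plain unweighted mean, which is reasonable here because the IID split gives every user the same number of samples. The paper's aggregation instead weights each client by its sample count, w = sum_k (n_k / n) * w_k; a hedged sketch of that variant, assuming a parallel list of per-client sample counts, might look like:

def weighted_average_weights(w, num_samples):
  """Sketch: average state dicts weighted by each client's sample count n_k."""
  total = sum(num_samples)
  w_avg = copy.deepcopy(w[0])
  for key in w_avg.keys():
    w_avg[key] = w[0][key] * (num_samples[0] / total)
    for i in range(1, len(w)):
      w_avg[key] += w[i][key] * (num_samples[i] / total)
  return w_avg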

Running

By default the script uses the default values defined in options.py; run it with:

python FedAvg.py 

To override parameters, for example to force CPU training and increase the number of global rounds, run (note that --gpu is parsed as an int, so pass 0 rather than False):

python FedAvg.py --epochs=15 --gpu=0

The full source code is available here: fedavg
