A Detailed Walkthrough of Federated Learning Code

References:
[1] H. Brendan McMahan, E. Moore, D. Ramage, S. Hampson, and B. Agüera y Arcas, "Communication-Efficient Learning of Deep Networks from Decentralized Data," arXiv e-prints, 2016.

Reference code:
https://github.com/AshwinRJ/Federated-Learning-PyTorch

When building a project with PyTorch, the code is usually split into three modules: data processing, model construction, and training control.

Federated learning (FedAvg) pseudocode

[Figure: federated learning (FedAvg) pseudocode]
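What the pseudocode boils down to, following [1]: in each round t the server samples a set S_t of m = max(int(C·K), 1) clients; every selected client k runs E local epochs of SGD starting from the current global weights and sends back its updated weights w_{t+1}^k; and the server aggregates

$$w_{t+1} = \sum_{k \in S_t} \frac{n_k}{n}\, w_{t+1}^{k}$$

where n_k is the number of samples held by client k and n = \sum_{k \in S_t} n_k. Note that average_weights in this repo (see utils.py below) takes a plain unweighted mean, which matches the paper's weighted average only when all clients hold equally sized shards, as they do under the equal-split samplers used here.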

Main program: federated_main.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details


if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()  # parse command-line arguments
    exp_details(args)  # print experiment details

    if args.gpu:  # use the GPU if one is specified
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load the dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # build the model
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # send the model to the device and put it in training mode
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy the weights
    global_weights = global_model.state_dict()

    # training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):  # tqdm draws a progress bar over the global rounds
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # aggregate the local weights into new global weights
        global_weights = average_weights(local_weights)

        # load the averaged weights into the global model
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # compute the average training accuracy over all users at every round
        
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print the global training loss every 'print_every' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # run test inference after training completes
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # save the training loss and training accuracy objects
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    # plotting
    import matplotlib
    matplotlib.use('Agg')  # select the non-interactive backend before importing pyplot
    import matplotlib.pyplot as plt

    # plot the training loss curve
    plt.figure()
    plt.title('Training Loss vs Communication Rounds')
    plt.plot(range(len(train_loss)), train_loss, color='r')
    plt.ylabel('Training loss')
    plt.xlabel('Communication rounds')
    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
    format(args.dataset, args.model, args.epochs, args.frac,
           args.iid, args.local_ep, args.local_bs))

    # plot average accuracy vs communication rounds
    plt.figure()
    plt.title('Average Accuracy vs Communication Rounds')
    plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
    plt.ylabel('Average accuracy')
    plt.xlabel('Communication rounds')
    plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
    format(args.dataset, args.model, args.epochs, args.frac,
            args.iid, args.local_ep, args.local_bs))


Argument parsing: options.py

args_parser() in options.py defines, among others:

Federated arguments: number of global rounds, number of users K, client fraction C, number of local epochs E, local batch size B, learning rate, and SGD momentum (default 0.5).

Model arguments: model type, number of kernels, kernel sizes, number of channels, normalization, number of filters, and max pooling.

Plus a handful of other options with their default values.
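A minimal sketch of what args_parser() looks like; the flags mirror the reference repo, but treat the exact defaults here as illustrative:

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments (notation follows the FedAvg paper [1])
    parser.add_argument('--epochs', type=int, default=10, help='number of global rounds')
    parser.add_argument('--num_users', type=int, default=100, help='number of users: K')
    parser.add_argument('--frac', type=float, default=0.1, help='fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=10, help='local epochs: E')
    parser.add_argument('--local_bs', type=int, default=10, help='local batch size: B')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum')
    # model and misc arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name: mlp or cnn')
    parser.add_argument('--dataset', type=str, default='mnist', help='mnist, fmnist or cifar')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    parser.add_argument('--num_channels', type=int, default=1, help='number of image channels')
    parser.add_argument('--optimizer', type=str, default='sgd', help='sgd or adam')
    parser.add_argument('--iid', type=int, default=1, help='1 for IID, 0 for non-IID')
    parser.add_argument('--unequal', type=int, default=0, help='0 for equal, 1 for unequal non-IID splits')
    parser.add_argument('--gpu', default=None, help='GPU id to use; defaults to CPU')
    parser.add_argument('--verbose', type=int, default=1, help='verbose printing')
    return parser.parse_args()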

Local updates: update.py

The main task of the data processing module is to build the dataset. To make this convenient for deep learning projects, PyTorch provides the Dataset class.

Building the dataset: class DatasetSplit(Dataset)

First, the official description of the Dataset class: a Dataset can be anything, but it always provides a __len__ method (invoked through Python's built-in len) and a __getitem__ method for indexing into its contents. (PyTorch official Chinese documentation)

The following draws on "How to build a dataset in PyTorch?"

Before building the dataset, be clear about what input data is needed and what data the training loop requires.

For example: suppose we have image samples of 1-yuan and 100-yuan bills, stored in two separate folders. Our input is the image data; besides the images themselves, we also need the class label matching each image in order to compute the loss.

Once it is clear what data needs to be built, the next step is to subclass PyTorch's Dataset class and write our own dataset class.

The class DatasetSplit(Dataset) below wraps PyTorch's Dataset class:

class DatasetSplit(Dataset):  # wraps the original dataset
    """An abstract Dataset class wrapped around the PyTorch Dataset class."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)  # torch.tensor() converts to tensors and copies the data

In the code above, __len__(self) is overridden; it simply returns the length of the index list, i.e. the number of samples in this user's split.

__getitem__(self, item) reads one sample from the underlying dataset and returns the image and label at position item as tensors.

What gets returned here is dictated by what the training code needs. In other words, we override __getitem__(self, index) to return exactly the data the training loop consumes.
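A quick usage sketch with DatasetSplit as defined above (the MNIST loading and the first-100-indexes shard are just for illustration):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# illustrative setup: load MNIST and treat the first 100 samples as one user's shard
dataset = datasets.MNIST('../data/mnist/', train=True, download=True,
                         transform=transforms.ToTensor())
user_idxs = range(100)

loader = DataLoader(DatasetSplit(dataset, user_idxs), batch_size=10, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([10, 1, 28, 28]) torch.Size([10])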

Local model update module: class LocalUpdate(object)

Initialization: __init__(self, args, dataset, idxs, logger)

    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))  # split this user's samples into train/val/test loaders
        self.device = 'cuda' if args.gpu else 'cpu'  # run on CUDA when args.gpu is set
        # Default criterion set to NLL loss function; paired with the models'
        # log_softmax outputs this is equivalent to cross-entropy loss
        self.criterion = nn.NLLLoss().to(self.device)
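A small sanity check of that claim, showing NLLLoss over log-probabilities agrees with CrossEntropyLoss over raw logits (illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)           # batch of 4 samples, 10 classes
targets = torch.tensor([1, 0, 3, 9])

nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
ce = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(nll, ce))  # True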

Splitting one user's data: train_val_test(self, dataset, idxs)

Input: dataset and the user's sample indexes
Output: train, validation, and test DataLoaders over the given user's data

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset and user indexes.
        """
        # split indexes for train, validation, and test (80/10/10)
        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9 * len(idxs)):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=int(len(idxs_val) / 10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=int(len(idxs_test) / 10), shuffle=False)
        return trainloader, validloader, testloader
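As a worked example (not repo code), a user holding 600 samples ends up with 480/60/60 train/val/test indexes:

idxs = list(range(600))
n = len(idxs)
print(len(idxs[:int(0.8 * n)]))              # 480 training indexes
print(len(idxs[int(0.8 * n):int(0.9 * n)]))  # 60 validation indexes
print(len(idxs[int(0.9 * n):]))              # 60 test indexes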

Local weight update: update_weights(self, model, global_round)

Input: model and the current global round number
Output: the updated weights and the average local training loss

First, the standard PyTorch training-loop pattern that update_weights follows:

optimizer = optim.SGD(model.parameters(), lr=0.01)
for epoch in range(num_epochs):
    train_loss = 0
    for step, (seq, label) in enumerate(data_loader):
        # compute the loss
        loss = criterion(model(seq), label.to(device))
        # clear the accumulated gradients
        optimizer.zero_grad()
        # backward() computes the gradients via backpropagation
        loss.backward()
        train_loss += loss.item()
        # the optimizer's step() updates the parameters
        optimizer.step()
    def update_weights(self, model, global_round):
        # set the model to training mode
        model.train()
        epochs_loss = []

        # set the optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.5)  # SGD with momentum 0.5
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)  # Adam with weight decay

        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, iter, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epochs_loss.append(sum(batch_loss) / len(batch_loss))

        return model.state_dict(), sum(epochs_loss) / len(epochs_loss)

Computing accuracy and loss: inference(self, model)

    def inference(self, model):
        """Returns the inference accuracy and loss."""

        model.eval()  # evaluation mode: layers such as dropout are disabled, no weights change
        loss, total, correct = 0.0, 0.0, 0.0

        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)

            # inference
            outputs = model(images)
            batch_loss = self.criterion(outputs, labels)
            loss += batch_loss.item()

            # prediction
            _, pred_labels = torch.max(outputs, 1)  # index of the largest output per row, i.e. the predicted class
            pred_labels = pred_labels.view(-1)  # view() reshapes the tensor, here flattening to 1-D
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

        accuracy = correct / total  # computed after accumulating over all batches
        return accuracy, loss

Utilities (dataset loading, weight averaging, experiment details): utils.py

Loading the dataset: get_dataset(args)

Input: the command-line arguments
Output: the training and test datasets plus the user groups, a dict whose keys are user indexes and whose values are the sample indexes assigned to each user

def get_dataset(args):
    "返回训练和测试数据集和用户组,用户用户组是字典,其中键是索引,值是每个用户的相应数据。"

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # transforms.Compose() chains several preprocessing steps
        # ToTensor() rescales pixel values from 0-255 down to 0-1
        # Normalize() then maps 0-1 to (-1, 1)

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # sample IID user data from CIFAR10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # sample non-IID user data from CIFAR10
            if args.unequal:
                # unequal splits per user are not implemented for CIFAR10
                raise NotImplementedError()
            else:
                # equal splits per user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset in ['mnist', 'fmnist']:
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)

        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # sample non-IID user data from MNIST
            if args.unequal:
                # unequal data splits per user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # equal data splits per user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups

Averaging the weights: average_weights(w)

def average_weights(w):
    """Returns the element-wise average of a list of weight state dicts."""

    w_avg = copy.deepcopy(w[0])
    # deepcopy makes a fully independent copy of the input: however the new
    # variable is modified, the original client weights are never affected
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
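A quick check that average_weights as defined above really produces the element-wise mean (illustrative state dicts):

import copy
import torch

w1 = {'fc.weight': torch.ones(2, 2)}         # pretend client 1 weights
w2 = {'fc.weight': torch.ones(2, 2) * 3.0}   # pretend client 2 weights

avg = average_weights([w1, w2])
print(avg['fc.weight'])  # all entries equal 2.0, the element-wise mean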

Printing experiment details: exp_details(args)

def exp_details(args):
    print('\nExperimental details:')
    print(f'    Model         : {args.model}')
    print(f'    Optimizer     : {args.optimizer}')
    print(f'    Learning rate : {args.lr}')
    print(f'    Global rounds : {args.epochs}\n')

    print('    Federated parameters:')
    if args.iid:
        print('    IID')
    else:
        print('    Non-IID')
    print(f'    Fraction of users : {args.frac}')
    print(f'    Local batch size  : {args.local_bs}')
    print(f'    Local epochs      : {args.local_ep}\n')
    return

Model definitions: models.py

MLP (multi-layer perceptron)

[Figure: MLP architecture]

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return self.softmax(x)
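A quick shape check of the MLP on MNIST-sized input (illustrative). One thing worth noticing: forward ends in nn.Softmax, while LocalUpdate's criterion is NLLLoss, which expects log-probabilities; CNNMnist and CNNCifar below end in log_softmax, which does match that criterion.

import torch

model = MLP(dim_in=28 * 28, dim_hidden=64, dim_out=10)
x = torch.randn(5, 1, 28, 28)  # a batch of 5 MNIST-sized images
probs = model(x)
print(probs.shape)       # torch.Size([5, 10])
print(probs.sum(dim=1))  # each row sums to 1: softmax probabilities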

CNN卷积神经网络

class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

class CNNFashion_Mnist(nn.Module):
    def __init__(self, args):
        super(CNNFashion_Mnist, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2) )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(7 * 7 * 32, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_channels)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
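The 16 * 5 * 5 flatten size can be verified by tracing a 32×32 CIFAR input through the layers (the Args stub below is a hypothetical stand-in for the parsed arguments):

import torch

class Args:            # hypothetical stub holding the one field CNNCifar reads
    num_classes = 10

model = CNNCifar(Args())
x = torch.randn(1, 3, 32, 32)  # one CIFAR-10 image
# conv1 (5x5): 32 -> 28, pool: 28 -> 14
# conv2 (5x5): 14 -> 10, pool: 10 -> 5   => 16 channels * 5 * 5 = 400 features
print(model(x).shape)  # torch.Size([1, 10])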

modelC (an all-convolutional network; defined in models.py but not imported by federated_main.py)

class modelC(nn.Module):
    def __init__(self, input_size, n_classes=10, **kwargs):
        super(modelC, self).__init__()
        self.conv1 = nn.Conv2d(input_size, 96, 3, padding=1)
        self.conv2 = nn.Conv2d(96, 96, 3, padding=1)
        self.conv3 = nn.Conv2d(96, 96, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(96, 192, 3, padding=1)  # widen to 192 channels so conv5's input matches
        self.conv5 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv6 = nn.Conv2d(192, 192, 3, padding=1, stride=2)
        self.conv7 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv8 = nn.Conv2d(192, 192, 1)

        self.class_conv = nn.Conv2d(192, n_classes, 1)

    def forward(self, x):
        x_drop = F.dropout(x, .2)
        conv1_out = F.relu(self.conv1(x_drop))
        conv2_out = F.relu(self.conv2(conv1_out))
        conv3_out = F.relu(self.conv3(conv2_out))
        conv3_out_drop = F.dropout(conv3_out, .5)
        conv4_out = F.relu(self.conv4(conv3_out_drop))
        conv5_out = F.relu(self.conv5(conv4_out))
        conv6_out = F.relu(self.conv6(conv5_out))
        conv6_out_drop = F.dropout(conv6_out, .5)
        conv7_out = F.relu(self.conv7(conv6_out_drop))
        conv8_out = F.relu(self.conv8(conv7_out))

        class_out = F.relu(self.class_conv(conv8_out))
        pool_out = F.adaptive_avg_pool2d(class_out, 1)
        pool_out.squeeze_(-1)
        pool_out.squeeze_(-1)
        return pool_out

Sampling: sampling.py

This file contains mnist_iid(dataset, num_users), mnist_noniid(dataset, num_users), and mnist_noniid_unequal(dataset, num_users).
These functions sample IID or non-IID client data from the MNIST dataset, with every client holding either an equal or an unequal amount of data.

Two further functions, cifar_iid(dataset, num_users) and cifar_noniid(dataset, num_users),
sample IID or non-IID data from the CIFAR10 dataset.
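As an example of what these samplers do, mnist_iid amounts to drawing each user's shard uniformly at random without replacement; a sketch along the lines of the reference repo:

import numpy as np

def mnist_iid(dataset, num_users):
    """Assign each user an equal, randomly drawn, disjoint subset of indexes."""
    num_items = int(len(dataset) / num_users)
    dict_users, all_idxs = {}, list(range(len(dataset)))
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])  # remove already-assigned indexes
    return dict_users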
