【代码解析(2)】Communication-Efficient Learning of Deep Networks from Decentralized Data

utils.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid


def get_dataset(args):
    """Return the train/test datasets and the per-user split of the train set.

    Args:
        args: parsed experiment options. This function reads
            ``args.dataset`` ('cifar', 'mnist' or 'fmnist'), ``args.iid``,
            ``args.unequal`` and ``args.num_users``.

    Returns:
        ``(train_dataset, test_dataset, user_groups)``. ``user_groups`` is
        whatever the selected ``sampling`` helper returns: a dict keyed by
        user index whose values are that user's share of the training data
        (NOTE(review): presumably sample indices -- confirm in sampling.py).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
        NotImplementedError: for non-IID *unequal* splits of CIFAR, for which
            no sampler exists.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # ToTensor turns a PIL image (HWC, values 0-255) into a CHW float
        # tensor in [0.0, 1.0]. Normalize then computes (x - mean) / std per
        # channel, so mean = std = 0.5 rescales every channel to [-1, 1].
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # Positional root dir; train=True/False selects the split; download
        # fetches the files if missing; transform is applied to every sample.
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Split the training data among users.
        if args.iid:
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID with unequal per-user sizes has no CIFAR sampler.
            raise NotImplementedError(
                'Non-IID unequal splits are not implemented for CIFAR')
        else:
            # Non-IID across users, but every user receives an equal share.
            user_groups = cifar_noniid(train_dataset, args.num_users)

    # Membership test: `args.dataset == 'mnist' or 'fmnist'` would always be
    # truthy and silently accept any dataset name.
    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            # Fashion-MNIST must be loaded with its own dataset class, not
            # MNIST, otherwise the digits dataset is used instead.
            dataset_cls = datasets.FashionMNIST

        # NOTE(review): the same normalization constants are applied to both
        # MNIST and Fashion-MNIST; they look like MNIST statistics -- confirm
        # whether Fashion-MNIST should use its own mean/std.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Split the training data among users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID, and users receive unequal amounts of data.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Non-IID, equal amount of data per user.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    else:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            "expected 'cifar', 'mnist' or 'fmnist'")

    return train_dataset, test_dataset, user_groups


def average_weights(w, weights=None):
    """Return the element-wise average of a list of model state dicts.

    Args:
        w: non-empty list of state dicts (parameter name -> tensor), e.g. the
            local model weights collected from the clients in one round.
            Keys are taken from ``w[0]``; every other dict must contain them.
        weights: optional sequence of non-negative per-dict weights (for
            example each client's number of local samples). When ``None``
            (the default) every dict counts equally, exactly as before.

    Returns:
        A new state dict holding the (weighted) mean of each entry. The input
        dicts and their tensors are not modified.

    Raises:
        ValueError: if ``w`` is empty, if ``weights`` has a different length
            from ``w``, or if ``weights`` does not sum to a positive value.
    """
    if not w:
        raise ValueError('average_weights() requires at least one state dict')

    # Deep copy so the in-place accumulation below never touches the
    # caller's tensors.
    w_avg = copy.deepcopy(w[0])

    if weights is None:
        # Unweighted mean: sum all dicts entry by entry, then divide by count.
        for key in w_avg.keys():
            for state in w[1:]:
                w_avg[key] += state[key]
            w_avg[key] = torch.div(w_avg[key], len(w))
        return w_avg

    if len(weights) != len(w):
        raise ValueError(
            f'got {len(weights)} weights for {len(w)} state dicts')
    total = float(sum(weights))
    if total <= 0:
        raise ValueError('weights must sum to a positive value')

    # Weighted mean: sum_i state_i[key] * (weight_i / total).
    for key in w_avg.keys():
        acc = w[0][key] * (weights[0] / total)
        for state, coef in zip(w[1:], weights[1:]):
            acc = acc + state[key] * (coef / total)
        w_avg[key] = acc
    return w_avg


def exp_details(args):
    """Print a human-readable summary of the experiment configuration.

    Reads ``args.model``, ``args.optimizer``, ``args.lr``, ``args.epochs``,
    ``args.iid``, ``args.frac``, ``args.local_bs`` and ``args.local_ep`` and
    writes one line per setting to stdout. Returns ``None``.

    Terminology: one epoch is a full forward and backward pass over all the
    training data; because that is usually too large, an epoch is split
    into smaller batches.
    """
    report = [
        '\nExperimental details:',
        f'    Model     : {args.model}',
        f'    Optimizer : {args.optimizer}',
        f'    Learning  : {args.lr}',
        f'    Global Rounds   : {args.epochs}\n',
        '    Federated parameters:',
        '    IID' if args.iid else '    Non-IID',
        f'    Fraction of users  : {args.frac}',
        f'    Local Batch size   : {args.local_bs}',
        f'    Local Epochs       : {args.local_ep}\n',
    ]
    for entry in report:
        print(entry)

你可能感兴趣的:(Xidian科研,python,去中心化,深度学习)