# utils.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
def get_dataset(args):
    """Return the train/test datasets and a per-user partition of the train set.

    Args:
        args: Parsed experiment options. This function reads
            ``args.dataset`` (``'cifar'``, ``'mnist'`` or ``'fmnist'``),
            ``args.iid``, ``args.unequal`` and ``args.num_users``.

    Returns:
        A tuple ``(train_dataset, test_dataset, user_groups)``.
        ``user_groups`` is whatever the selected ``sampling`` helper returns;
        presumably a dict mapping each user index to that user's data
        (sample indices) -- verify against ``sampling.py``.

    Raises:
        NotImplementedError: for CIFAR with non-IID *and* unequal splits.
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        # ToTensor() turns an HWC PIL image with values in [0, 255] into a
        # CHW float tensor in [0.0, 1.0]. Normalize then applies
        # (x - mean) / std per channel; with mean = std = 0.5 this maps
        # [0, 1] onto [-1, 1].
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # download=True fetches CIFAR-10 into data_dir if it is missing;
        # train=False selects the test split instead of the training split.
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # Partition the training data amongst users.
        if args.iid:
            user_groups = cifar_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID data with unequal per-user split sizes is not
            # implemented for CIFAR.
            raise NotImplementedError()
        else:
            # Non-IID data across users, but every user gets an equal split.
            user_groups = cifar_noniid(train_dataset, args.num_users)

    # Explicit membership test: a bare `... or 'fmnist'` would always be
    # truthy and silently route every unknown name to MNIST.
    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            dataset_cls = datasets.FashionMNIST

        # NOTE(review): (0.1307,) / (0.3081,) appear to be the commonly used
        # MNIST mean/std; they are reused unchanged for Fashion-MNIST.
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Partition the training data amongst users.
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        elif args.unequal:
            # Non-IID across users and unequal split sizes per user.
            user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
        else:
            # Non-IID across users, equal split sizes per user.
            user_groups = mnist_noniid(train_dataset, args.num_users)

    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    return train_dataset, test_dataset, user_groups
def average_weights(w):
    """Return the element-wise mean of a list of weight dicts.

    Args:
        w: A list of state dicts (mappings from parameter name to tensor)
            that share the keys of ``w[0]``; presumably the local weights
            collected from each user in one round -- confirm with caller.

    Returns:
        A new dict, built from a deep copy of ``w[0]`` so no input tensor is
        modified, whose value for every key is the sum of that key over all
        dicts in ``w`` divided by ``len(w)``.
    """
    n_models = len(w)
    # Deep-copy the first model so the in-place accumulation below never
    # writes into a caller-owned tensor.
    averaged = copy.deepcopy(w[0])
    for name in averaged.keys():
        total = averaged[name]
        for state in w[1:]:
            total += state[name]
        averaged[name] = torch.div(total, n_models)
    return averaged
def exp_details(args):
    """Print a summary of the experiment configuration to stdout.

    Reads ``model``, ``optimizer``, ``lr``, ``epochs``, ``iid``, ``frac``,
    ``local_bs`` and ``local_ep`` from ``args`` and returns ``None``.

    Terminology: one epoch is one full forward-and-backward pass of the whole
    dataset through the network; because an epoch is usually too large to
    process at once, it is split into several smaller batches.
    """
    # General training settings.
    print('\nExperimental details:')
    print(f' Model : {args.model}')
    print(f' Optimizer : {args.optimizer}')
    print(f' Learning : {args.lr}')
    print(f' Global Rounds : {args.epochs}\n')

    # Federated-learning settings.
    print(' Federated parameters:')
    print(' IID' if args.iid else ' Non-IID')
    print(f' Fraction of users : {args.frac}')
    print(f' Local Batch size : {args.local_bs}')
    print(f' Local Epochs : {args.local_ep}\n')