This benchmark dataset is aligned with the following publication: Federated Optimization in Heterogeneous Networks (https://arxiv.org/pdf/1812.06127.pdf). MLSys 2020.
Explanation of the dataset's test split (the train split has the same structure):
sh stats.sh test
# result
####################################
DATASET: MNIST
1000 users
7371 samples (total)
7.37 samples per user (mean)
num_samples (std): 16.08
num_samples (std/mean): 2.18
num_samples (skewness): 8.20
num_sam num_users
0 927
20 40
40 19
60 8
80 1
100 0
120 1
140 2
160 0
180 1
The dataset is a dictionary with three keys: num_samples (a list), users (a list), and user_data (a dict).
user_data nests y and x under keys taken from users, e.g. 'f_00544'; y holds the label of each image that user owns (only two possible classes, set deliberately), and x holds the images themselves.
From the output above, 927 users own between 0 and 20 samples, and 8 users own between 60 and 80.
file_path='data/MNIST/test/all_data_0_niid_0_keep_10_test_9.json'
# cdata is the dict obtained from json.load()
len(cdata['user_data']['f_00913']['y'])
266 # the only test-set user with more than 180 samples owns 266 of them
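As a quick sanity check, the same file can be inspected directly. A minimal sketch, assuming the path above is relative to the repository root and that each image is stored as a flattened 28*28 = 784 list:
import json

file_path = 'data/MNIST/test/all_data_0_niid_0_keep_10_test_9.json'
with open(file_path, 'r') as inf:
    cdata = json.load(inf)

print(sorted(cdata.keys()))                 # ['num_samples', 'user_data', 'users']
print(len(cdata['users']))                  # number of users in this file
print(cdata['num_samples'][:5])             # per-user sample counts
u = 'f_00913'
print(len(cdata['user_data'][u]['y']))      # 266, as above
print(set(cdata['user_data'][u]['y']))      # only two distinct labels per user (set deliberately)
print(len(cdata['user_data'][u]['x'][0]))   # 784 if images are flattened 28*28 vectors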
Below is the data-processing part of Federated Averaging on non-iid MNIST. The main difficulty is keeping the many variables apart: the code constantly converts between lists, dicts, tuples, and tensors, so it pays to trace the types carefully.
import json
import logging
import os

import numpy as np
import torch

# load data
dataset = load_data(args, args.dataset)
def load_data(args, dataset_name):
    # check if the centralized training is enabled
    centralized = True if args.client_num_in_total == 1 else False
    # For MNIST the number of clients ends up being either 1 or 1000; it seems to be
    # determined by the dataset itself, even if client_num_in_total is set to 2.
    # check if the full-batch training is enabled
    args_batch_size = args.batch_size
    if args.batch_size <= 0:
        full_batch = True
        args.batch_size = 128  # i.e. -1 effectively means 128; larger clients are still split into batches
    else:
        full_batch = False
    if dataset_name == "mnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_mnist(args.batch_size)
        """
        For shallow NN or linear models,
        we uniformly sample a fraction of clients each round (as the original FedAvg paper)
        """
        args.client_num_in_total = client_num
    if centralized:  # gather every client's data onto a single client
        train_data_local_num_dict = {0: sum(user_train_data_num for user_train_data_num in train_data_local_num_dict.values())}
        train_data_local_dict = {0: [batch for cid in sorted(train_data_local_dict.keys()) for batch in train_data_local_dict[cid]]}  # concatenate all clients' batches
        test_data_local_dict = {0: [batch for cid in sorted(test_data_local_dict.keys()) for batch in test_data_local_dict[cid]]}
        args.client_num_in_total = 1
    if full_batch:
        train_data_global = combine_batches(train_data_global)
        test_data_global = combine_batches(test_data_global)
        train_data_local_dict = {cid: combine_batches(train_data_local_dict[cid]) for cid in train_data_local_dict.keys()}
        test_data_local_dict = {cid: combine_batches(test_data_local_dict[cid]) for cid in test_data_local_dict.keys()}
        args.batch_size = args_batch_size  # restore the original value (back to -1?)
    dataset = [train_data_num, test_data_num, train_data_global, test_data_global,
               train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]
    return dataset
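For orientation, the caller unpacks the returned list positionally; a minimal sketch of that step (the unpacking pattern is my assumption, only the element order comes from the code above):
dataset = load_data(args, args.dataset)
[train_data_num, test_data_num, train_data_global, test_data_global,
 train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num] = dataset

# train_data_global / test_data_global: flat lists of (x, y) batch tuples over all clients
# *_data_local_dict: client_idx -> list of (x, y) batch tuples for that client
# train_data_local_num_dict: client_idx -> number of local training samples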
def combine_batches(batches):
    full_x = torch.from_numpy(np.asarray([])).float()
    full_y = torch.from_numpy(np.asarray([])).long()
    for (batched_x, batched_y) in batches:
        full_x = torch.cat((full_x, batched_x), 0)
        full_y = torch.cat((full_y, batched_y), 0)
    return [(full_x, full_y)]  # a single (x, y) tensor pair holding all samples
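To make the tensor bookkeeping concrete, here is a toy picture of what combine_batches yields; the shapes are invented for the example, and the concatenation is written out directly with torch.cat:
import torch

# two batches in the (x, y) format produced by batch_data (see below)
b1 = (torch.zeros(128, 784), torch.zeros(128, dtype=torch.long))
b2 = (torch.ones(8, 784), torch.ones(8, dtype=torch.long))

# combine_batches([b1, b2]) boils down to concatenating along dim 0:
x_all = torch.cat([b1[0], b2[0]], 0)   # torch.Size([136, 784])
y_all = torch.cat([b1[1], b2[1]], 0)   # torch.Size([136])
full = [(x_all, y_all)]                # a single "batch" covering the whole list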
def load_partition_data_mnist(batch_size):
    train_path = "./../../../data/MNIST/train"
    test_path = "./../../../data/MNIST/test"
    users, groups, train_data, test_data = read_data(train_path, test_path)
    if len(groups) == 0:
        groups = [None for _ in users]
    train_data_num = 0
    test_data_num = 0
    train_data_local_dict = dict()
    test_data_local_dict = dict()
    train_data_local_num_dict = dict()
    train_data_global = list()
    test_data_global = list()
    client_idx = 0
    for u, g in zip(users, groups):
        user_train_data_num = len(train_data[u]['x'])  # number of local training samples for this user
        user_test_data_num = len(test_data[u]['x'])
        train_data_num += user_train_data_num
        test_data_num += user_test_data_num
        train_data_local_num_dict[client_idx] = user_train_data_num
        # transform to batches
        train_batch = batch_data(train_data[u], batch_size)  # a list of (x, y) tuples, roughly [([[28*28], [28*28]], [1, 1]), ...]
        test_batch = batch_data(test_data[u], batch_size)
        # index using client index
        train_data_local_dict[client_idx] = train_batch
        test_data_local_dict[client_idx] = test_batch
        train_data_global += train_batch
        test_data_global += test_batch
        client_idx += 1
    client_num = client_idx
    class_num = 10
    return client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num
# returns: client_num, train_data_num, test_data_num are scalars; train_data_global, test_data_global are lists;
# train_data_local_num_dict, train_data_local_dict, test_data_local_dict are dicts; class_num is a scalar
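Plugging in the stats.sh numbers from the top of this post, the returned values can be checked quickly. A sketch under the assumption that the same 1000-user split is on disk (batch_size=10 is just an example value):
client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num \
    = load_partition_data_mnist(batch_size=10)

print(client_num)                    # 1000, matching "1000 users" from stats.sh
print(test_data_num)                 # 7371, matching the stats.sh total for the test split
print(class_num)                     # 10
print(train_data_local_num_dict[0])  # local training-sample count of client 0
print(len(train_data_local_dict[0])) # number of batches held by client 0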
def read_data(train_data_dir, test_data_dir):
    clients = []
    groups = []
    train_data = {}
    test_data = {}
    train_files = os.listdir(train_data_dir)
    train_files = [f for f in train_files if f.endswith('.json')]  # get files
    for f in train_files:
        file_path = os.path.join(train_data_dir, f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        clients.extend(cdata['users'])
        train_data.update(cdata['user_data'])
    test_files = os.listdir(test_data_dir)
    test_files = [f for f in test_files if f.endswith('.json')]
    for f in test_files:
        file_path = os.path.join(test_data_dir, f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        test_data.update(cdata['user_data'])
    clients = list(sorted(train_data.keys()))
    return clients, groups, train_data, test_data
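A small usage sketch for read_data, using the same relative paths as load_partition_data_mnist (the printed expectations follow from the JSON structure described earlier):
users, groups, train_data, test_data = read_data("./../../../data/MNIST/train",
                                                 "./../../../data/MNIST/test")
print(len(users))                         # one entry per client, e.g. 1000 for this benchmark
print(users[:2])                          # client ids such as 'f_00000', 'f_00001'
print(groups)                             # [] -- this simplified reader never fills groups
print(list(train_data[users[0]].keys()))  # contains 'x' and 'y' for each client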
def batch_data(data, batch_size):
    '''
    data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client)
    returns a list of (x, y) batches, each holding at most batch_size samples
    x looks like [[28*28], [28*28], [28*28], ...] (flattened images); y looks like [1, 0, 1, 0, ...]
    '''
    data_x = data['x']
    data_y = data['y']
    # randomly shuffle data
    np.random.seed(100)
    rng_state = np.random.get_state()
    np.random.shuffle(data_x)
    np.random.set_state(rng_state)
    np.random.shuffle(data_y)  # restoring the RNG state keeps x and y in the same shuffled order
    # loop through mini-batches
    batch_data = list()
    for i in range(0, len(data_x), batch_size):  # slices of size batch_size (128 when -1 was passed)
        batched_x = data_x[i:i + batch_size]
        batched_y = data_y[i:i + batch_size]
        batched_x = torch.from_numpy(np.asarray(batched_x)).float()
        batched_y = torch.from_numpy(np.asarray(batched_y)).long()
        batch_data.append((batched_x, batched_y))
    return batch_data
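Finally, a toy run of batch_data on a fabricated client (three flattened 28*28 images with invented labels), showing the list-of-(x, y)-tuples structure described in the docstring:
import numpy as np

toy_client = {
    'x': [np.zeros(784).tolist() for _ in range(3)],  # three flattened 28*28 images
    'y': [1, 0, 1],                                    # two-class labels, as in the benchmark
}
batches = batch_data(toy_client, batch_size=2)
print(len(batches))          # 2 -> batch sizes 2 and 1
bx, by = batches[0]
print(bx.shape, bx.dtype)    # torch.Size([2, 784]) torch.float32
print(by.shape, by.dtype)    # torch.Size([2]) torch.int64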