FedML Data Processing (the MNIST non-IID Dataset)

About the MNIST non-IID dataset

This benchmark dataset is aligned with the following publication: Federated Optimization in Heterogeneous Networks, MLSys 2020 (https://arxiv.org/pdf/1812.06127.pdf).

Explanation of the test split of the dataset (the training split has the same structure):

sh stats.sh test

# result
####################################
DATASET: MNIST
1000 users
7371 samples (total)
7.37 samples per user (mean)
num_samples (std): 16.08
num_samples (std/mean): 2.18
num_samples (skewness): 8.20

num_samples  num_users
0        927
20       40
40       19
60       8
80       1
100      0
120      1
140      2
160      0
180      1

The dataset is stored as a dictionary with three keys: num_samples (a list), users (a list), and user_data (a dictionary).

user_data is keyed by the entries of users (e.g. 'f_00544'); each value nests x and y, where y holds the labels of that user's images (each user only sees two classes, by construction of the partition) and x holds the images themselves.

From the table above, 927 users each hold between 0 and 20 samples, and 8 users hold between 60 and 80.

file_path = 'data/MNIST/test/all_data_0_niid_0_keep_10_test_9.json'
# cdata is the dict obtained from json.load()
len(cdata['user_data']['f_00913']['y'])
266  # the only user in the test set holding more than 180 samples
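
To make the structure concrete, here is a minimal sketch that reloads the same test-split JSON and recomputes the per-user histogram; it assumes the relative path above resolves from the working directory and that the bucket edges match the stats.sh table:

import json

file_path = 'data/MNIST/test/all_data_0_niid_0_keep_10_test_9.json'
with open(file_path, 'r') as inf:
    cdata = json.load(inf)  # keys: num_samples, users, user_data

# number of samples held by each user, read from user_data
counts = [len(cdata['user_data'][u]['y']) for u in cdata['users']]
print(len(counts), sum(counts))  # expected: 1000 users, 7371 samples in total

# bucket the counts by 20, mirroring the stats.sh table above
hist = {}
for c in counts:
    hist[c // 20 * 20] = hist.get(c // 20 * 20, 0) + 1
for b in sorted(hist):
    print(b, hist[b])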

Below is the data-processing part of federated averaging (FedAvg) on MNIST non-IID. The main difficulty is keeping the many variables apart: the code keeps converting between lists, dicts, tuples, and tensors, so it is worth tracing the flow carefully.

import json
import logging
import os

import numpy as np
import torch

# load data
dataset = load_data(args, args.dataset)

def load_data(args, dataset_name):
    # check if the centralized training is enabled
    centralized = True if args.client_num_in_total == 1 else False
    # for MNIST the client count is effectively fixed by the dataset to 1 or 1000;
    # even if client_num_in_total is set to e.g. 2, it is overwritten below

    # check if the full-batch training is enabled
    args_batch_size = args.batch_size
    if args.batch_size <= 0:
        full_batch = True
        args.batch_size = 128  # i.e. -1 stands for 128 here; larger clients are still split into batches
    else:
        full_batch = False

    if dataset_name == "mnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_mnist(args.batch_size)
        """
        For shallow NN or linear models, 
        we uniformly sample a fraction of clients each round (as the original FedAvg paper)
        """
        args.client_num_in_total = client_num

    if centralized:  # merge every client's data into a single client
        train_data_local_num_dict = {0: sum(user_train_data_num for user_train_data_num in train_data_local_num_dict.values())}
        train_data_local_dict = {0: [batch for cid in sorted(train_data_local_dict.keys()) for batch in train_data_local_dict[cid]]}  # concatenate all clients' local batches
        test_data_local_dict = {0: [batch for cid in sorted(test_data_local_dict.keys()) for batch in test_data_local_dict[cid]]}
        args.client_num_in_total = 1

    if full_batch:
        train_data_global = combine_batches(train_data_global)
        test_data_global = combine_batches(test_data_global)
        train_data_local_dict = {cid: combine_batches(train_data_local_dict[cid]) for cid in train_data_local_dict.keys()}
        test_data_local_dict = {cid: combine_batches(test_data_local_dict[cid]) for cid in test_data_local_dict.keys()}
        args.batch_size = args_batch_size  # restore the original batch_size (i.e. back to -1)

    dataset = [train_data_num, test_data_num, train_data_global, test_data_global,
               train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]
    return dataset
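
# For reference, a minimal sketch of how the returned dataset list is typically
# unpacked by the caller (the names mirror the list assembled at the end of
# load_data; the exact call site in FedML may differ):
[train_data_num, test_data_num, train_data_global, test_data_global,
 train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
 class_num] = dataset

for batched_x, batched_y in train_data_local_dict[0]:  # one client's local batches
    pass  # batched_x: FloatTensor [B, 784], batched_y: LongTensor [B]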

def combine_batches(batches):
    full_x = torch.from_numpy(np.asarray([])).float()  # empty seed tensors to concatenate onto
    full_y = torch.from_numpy(np.asarray([])).long()
    for (batched_x, batched_y) in batches:
        full_x = torch.cat((full_x, batched_x), 0)
        full_y = torch.cat((full_y, batched_y), 0)
    return [(full_x, full_y)]  # a single full-batch (x, y) tensor pair, wrapped in a list
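
# A quick standalone illustration of what combine_batches does, using toy tensors
# shaped like the output of batch_data (128 and 40 are arbitrary batch sizes; the
# sketch assumes combine_batches above is in scope):
toy_batches = [
    (torch.randn(128, 784), torch.zeros(128, dtype=torch.long)),
    (torch.randn(40, 784), torch.ones(40, dtype=torch.long)),
]
full_x, full_y = combine_batches(toy_batches)[0]
print(full_x.shape, full_y.shape)  # torch.Size([168, 784]) torch.Size([168])
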
def load_partition_data_mnist(batch_size):
    train_path = "./../../../data/MNIST/train"
    test_path = "./../../../data/MNIST/test"
    users, groups, train_data, test_data = read_data(train_path, test_path)

    if len(groups) == 0:
        groups = [None for _ in users]
    train_data_num = 0
    test_data_num = 0
    train_data_local_dict = dict()
    test_data_local_dict = dict()
    train_data_local_num_dict = dict()
    train_data_global = list()
    test_data_global = list()
    client_idx = 0
    for u, g in zip(users, groups):
        user_train_data_num = len(train_data[u]['x'])  # number of local training samples for this user
        user_test_data_num = len(test_data[u]['x'])
        train_data_num += user_train_data_num
        test_data_num += user_test_data_num
        train_data_local_num_dict[client_idx] = user_train_data_num

        # transform to batches
        train_batch = batch_data(train_data[u], batch_size)  # a list of (x, y) tuples: x is a [B, 784] float tensor, y a [B] long tensor
        test_batch = batch_data(test_data[u], batch_size)

        # index using client index
        train_data_local_dict[client_idx] = train_batch
        test_data_local_dict[client_idx] = test_batch
        train_data_global += train_batch
        test_data_global += test_batch
        client_idx += 1
    client_num = client_idx
    class_num = 10

    return client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
           train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num
    # scalar, scalar, scalar, list, list,
    # dict, dict, dict, scalar

def read_data(train_data_dir, test_data_dir):
    clients = []
    groups = []
    train_data = {}
    test_data = {}

    train_files = os.listdir(train_data_dir)
    train_files = [f for f in train_files if f.endswith('.json')]  # get files
    for f in train_files:
        file_path = os.path.join(train_data_dir, f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        clients.extend(cdata['users'])
        train_data.update(cdata['user_data'])

    test_files = os.listdir(test_data_dir)
    test_files = [f for f in test_files if f.endswith('.json')]
    for f in test_files:
        file_path = os.path.join(test_data_dir, f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        test_data.update(cdata['user_data'])

    clients = list(sorted(train_data.keys()))

    return clients, groups, train_data, test_data
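
# A minimal sketch of calling read_data directly (the relative paths are the same
# ones used in load_partition_data_mnist and are assumed to resolve from the
# working directory):
users, groups, train_data, test_data = read_data("./../../../data/MNIST/train",
                                                 "./../../../data/MNIST/test")
print(len(users))                 # 1000 clients, per the stats.sh output above
uid = users[0]                    # an 'f_*'-style id such as 'f_00544'
print(len(train_data[uid]['x']))  # number of training images for this client
print(set(train_data[uid]['y']))  # at most two distinct digits per client
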
def batch_data(data, batch_size):
    '''
    data is a dict := {'x': [list of images], 'y': [list of labels]} for one client
    returns a list of (batched_x, batched_y) tensor pairs, each with at most batch_size samples
    x looks like [[784 floats], [784 floats], ...] (flattened 28*28 images); y looks like [1, 0, 1, 0, ...]
    '''
    data_x = data['x']
    data_y = data['y']

    # randomly shuffle data
    np.random.seed(100)
    rng_state = np.random.get_state()
    np.random.shuffle(data_x)
    np.random.set_state(rng_state)
    np.random.shuffle(data_y)  # same RNG state, so x and y are shuffled in the same order

    # loop through mini-batches
    batch_data = list()
    for i in range(0, len(data_x), batch_size):  # a batch_size of -1 has already been replaced by 128 upstream
        batched_x = data_x[i:i + batch_size]
        batched_y = data_y[i:i + batch_size]
        batched_x = torch.from_numpy(np.asarray(batched_x)).float()
        batched_y = torch.from_numpy(np.asarray(batched_y)).long()
        batch_data.append((batched_x, batched_y))
    return batch_data
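
To see the shapes this produces, here is a small usage sketch with toy data (the client dict, the two label classes 3 and 7, and the size 300 are made up for illustration; 784 = 28*28 flattened pixels):

import numpy as np

toy = {'x': np.random.rand(300, 784).tolist(),
       'y': np.random.choice([3, 7], size=300).tolist()}

batches = batch_data(toy, batch_size=128)
print(len(batches))        # 3 batches: 128 + 128 + 44 samples
bx, by = batches[0]
print(bx.shape, bx.dtype)  # torch.Size([128, 784]) torch.float32
print(by.shape, by.dtype)  # torch.Size([128]) torch.int64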

 
