Non-IID User Data Sampling for Federated Learning
Straight to the code. A minimal sketch of the core sampling step comes first; the full preprocessing script follows.
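The heart of the script is the label-sampling loop inside get_dataset: each user is assigned a random subset of class labels, and the draw is repeated until every class is held by at least one user. Here is that step in isolation (a minimal sketch; the counts are illustrative, not the script's defaults):

import random
from collections import Counter

num_classes, num_users, ways = 10, 5, 4  # illustrative values
# Note: the loop can only terminate if num_users * ways >= num_classes.
while True:
    user_labels = [random.sample(range(num_classes), ways) for _ in range(num_users)]
    class_count = Counter(lb for labels in user_labels for lb in labels)
    if len(class_count) == num_classes:  # every class held by at least one user
        break
print(user_labels)   # e.g. [[3, 0, 9, 5], [1, 4, 2, 8], ...]
print(class_count)   # how many users share each class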
import os
import random  # needed by random.sample() in get_dataset below
from collections import Counter

import numpy as np
import pandas as pd
import torch
from scipy.io import loadmat
from tqdm import tqdm

from datasets.SequenceDatasets import dataset
from datasets.sequence_aug import *
from options import args_parser

args = args_parser()
def get_files(root, N):
    '''
    Generate the per-class sample lists used to build the final training and test sets.
    root: location of the CWRU data set
    N:    working-condition selector (a one-element string/sequence); its first
          character indexes one of the four rpm conditions below
    '''
    dataname = {0: ["97.mat", "105.mat", "118.mat", "130.mat", "169.mat", "185.mat", "197.mat", "209.mat", "222.mat",
                    "234.mat"],  # 1797 rpm
                1: ["98.mat", "106.mat", "119.mat", "131.mat", "170.mat", "186.mat", "198.mat", "210.mat", "223.mat",
                    "235.mat"],  # 1772 rpm
                2: ["99.mat", "107.mat", "120.mat", "132.mat", "171.mat", "187.mat", "199.mat", "211.mat", "224.mat",
                    "236.mat"],  # 1750 rpm
                3: ["100.mat", "108.mat", "121.mat", "133.mat", "172.mat", "188.mat", "200.mat", "212.mat", "225.mat",
                    "237.mat"]}  # 1730 rpm
    datasetname = ["12k Drive End Bearing Fault Data", "12k Fan End Bearing Fault Data",
                   "48k Drive End Bearing Fault Data",
                   "Normal Baseline Data"]
    label = list(range(10))
    data = [[] for _ in range(args.num_classes)]  # one slot per class; each holds a single list of windows
    lab = [[] for _ in range(args.num_classes)]
    m = int(N[0])
    num = []  # number of windows available per class
    for n in tqdm(range(len(dataname[m]))):
        # The first file of each condition is the healthy (normal baseline) recording;
        # the remaining nine are fault recordings from the 12k drive-end folder.
        if n == 0:
            path1 = os.path.join(root, datasetname[3], dataname[m][n]).replace("\\", "/")
        else:
            path1 = os.path.join(root, datasetname[0], dataname[m][n]).replace("\\", "/")
        data1, lab1 = data_load(path1, dataname[m][n], label=label[n])
        data[n].append(data1)
        lab[n].append(lab1)
        num.append(len(data1))
    return [data, lab], num
def data_load(filename, axisname, label):
    '''
    Cut one recording into fixed-length samples.
    filename: full path to the .mat file
    axisname: file name, used to build the channel key ----> "_DE_time", "_FE_time", "_BA_time"
    '''
    signal_size = 1024
    axis = ["_DE_time", "_FE_time", "_BA_time"]
    datanumber = axisname.split(".")
    # Keys inside the .mat files are zero-padded to three digits,
    # e.g. "X097_DE_time" for 97.mat but "X105_DE_time" for 105.mat.
    if int(datanumber[0]) < 100:
        realaxis = "X0" + datanumber[0] + axis[0]
    else:
        realaxis = "X" + datanumber[0] + axis[0]
    fl = loadmat(filename)[realaxis]
    data = []
    lab = []
    # Slice the signal into non-overlapping windows of signal_size points.
    start, end = 0, signal_size
    while end <= fl.shape[0]:
        data.append(fl[start:end])
        lab.append(label)
        start += signal_size
        end += signal_size
    return data, lab
def del_file(path_data):
    # Recursively empty the output directory before regenerating the split.
    for i in os.listdir(path_data):
        file_data = os.path.join(path_data, i)
        if os.path.isfile(file_data):
            os.remove(file_data)
        else:
            del_file(file_data)
def get_dataset(args):
    rootpath = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))).replace("\\", "/")
    data_dir = os.path.join(rootpath, "data_preprocessing", "cwru").replace("\\", "/")
    output_dir = os.path.join(rootpath, "data").replace("\\", "/")
    os.makedirs(output_dir, exist_ok=True)
    task = f'{args.data_number}'
    print('task:', task)
    normlizetype = 'mean-std'  # per-sample standardization ('mean - std' with spaces was a typo)
    delpath = os.path.join(rootpath, "data", args.dataset).replace("\\", "/")
    del_file(delpath)
    # Transform definitions kept from the original pipeline (not applied in this script).
    data_transforms = {
        'train': Compose([
            Reshape(),
            Normalize(normlizetype),
            # RandomAddGaussian(),
            # RandomScale(),
            # RandomStretch(),
            # RandomCrop(),
            Retype(),
        ]),
        'val': Compose([
            Reshape(),
            Normalize(normlizetype),
            Retype(),
        ])
    }
    # Get the source train and val data.
    list_data, num = get_files(data_dir, task)
    # degree_noniid in (0, 1] controls how many of the classes each user sees.
    args.ways = int(args.num_classes * args.degree_noniid)
    # Resample until every class is held by at least one user; otherwise the
    # per-class quotas below would divide by zero.
    while True:
        local_lb = []
        for i in range(args.num_users):
            # Draw `ways` distinct class labels for user i.
            local_lb.append(random.sample(range(0, args.num_classes), args.ways))
        locallb = [token for st in local_lb for token in st]
        class_count = Counter(locallb)  # class -> number of users holding it
        if len(class_count) == args.num_classes:
            break
    # Per-class quotas: each class's 80/20 train/test pool is divided evenly
    # among the class_count[i] users that hold class i.
    trclasslength = []
    teclasslength = []
    for i in range(args.num_classes):
        trlength = int(num[i] * 0.8 // class_count[i])
        telength = int(num[i] * 0.2 // class_count[i])
        trclasslength.append(trlength)
        teclasslength.append(telength)
    # A plain dict replaces the original `names = locals()` trick, which relies
    # on undefined behavior when writing to locals() inside a function.
    names = {}
    for j in range(len(list_data[0])):
        # First 80% of each class's windows go to training, the rest to testing.
        names[f'tedata_pd{j}'] = pd.DataFrame({"data": list_data[0][j][0], "label": list_data[1][j][0]})
        names[f'trdata_pd{j}'] = names[f'tedata_pd{j}'].iloc[:int(num[j] * 0.8)]
        names[f'tedata_pd{j}'].drop(names[f'tedata_pd{j}'].index[0:int(num[j] * 0.8)], inplace=True)
    userdatapath = []
    for i in range(args.num_users):
        trconcat = []
        teconcat = []
        for each_class in local_lb[i]:
            # Take this user's quota from the front of the class pool, then drop
            # those rows so the next user holding the class gets fresh samples.
            names[f'train_pd{each_class}'] = names[f'trdata_pd{each_class}'].iloc[:trclasslength[each_class]]
            names[f'trdata_pd{each_class}'].drop(names[f'trdata_pd{each_class}'].index[0:trclasslength[each_class]], inplace=True)
            names[f'test_pd{each_class}'] = names[f'tedata_pd{each_class}'].iloc[:teclasslength[each_class]]
            names[f'tedata_pd{each_class}'].drop(names[f'tedata_pd{each_class}'].index[0:teclasslength[each_class]], inplace=True)
            trconcat.append(names[f'train_pd{each_class}'])
            teconcat.append(names[f'test_pd{each_class}'])
        names[f'trdf{i}'] = pd.concat(trconcat)
        names[f'tedf{i}'] = pd.concat(teconcat)
        xtrain = names[f'trdf{i}']['data'].values
        ytrain = names[f'trdf{i}']['label'].values
        xtest = names[f'tedf{i}']['data'].values
        ytest = names[f'tedf{i}']['label'].values
        dat_dict = dict()
        # Each window is (signal_size, 1); stack and move channels first -> (N, 1, signal_size).
        dat_dict["samples"] = torch.tensor(np.stack(xtrain))
        dat_dict["samples"] = dat_dict["samples"].permute(0, 2, 1)
        dat_dict["labels"] = torch.from_numpy(ytrain)
        user_dir = os.path.join(output_dir, args.dataset, f"{args.data_number}", f"user{i}").replace("\\", "/")
        os.makedirs(user_dir, exist_ok=True)
        torch.save(dat_dict, os.path.join(user_dir, "train.pt").replace("\\", "/"))
        dat_dict = dict()
        dat_dict["samples"] = torch.tensor(np.stack(xtest))
        dat_dict["samples"] = dat_dict["samples"].permute(0, 2, 1)
        dat_dict["labels"] = torch.from_numpy(ytest)
        torch.save(dat_dict, os.path.join(user_dir, "test.pt").replace("\\", "/"))
        userdatapath.append(user_dir)
    return userdatapath, local_lb
if __name__ == '__main__':
    args = args_parser()
    get_dataset(args)
    print('finish!')
    torch.cuda.empty_cache()
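The script writes one train.pt/test.pt pair per user, each a dict holding a (N, 1, 1024) samples tensor and a labels tensor. In the repo these shards are presumably consumed through datasets.SequenceDatasets.dataset with the transforms defined above; as a stand-in, here is a minimal sketch of loading one user's shard into a DataLoader (the path and batch size are illustrative assumptions):

import torch
from torch.utils.data import TensorDataset, DataLoader

shard = torch.load("data/cwru/0/user0/train.pt")       # illustrative path: <output_dir>/<dataset>/<data_number>/user<i>
ds = TensorDataset(shard["samples"].float(), shard["labels"].long())
loader = DataLoader(ds, batch_size=32, shuffle=True)   # illustrative batch size
for x, y in loader:
    print(x.shape, y.shape)                            # e.g. torch.Size([32, 1, 1024]) torch.Size([32])
    break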