First, deploy the FedML framework from Prof. He's group, and then work through what each part does.
I am a complete beginner myself; if you are also studying federated learning and come across this post, feel free to contact me so we can learn together.
This folder corresponds to the two major federated learning scenarios, "cross-device" and "cross-silo".
Cross-device means coordinating a very large number of mobile and edge-device applications, such as mobile keyboard prediction.
Cross-silo involves only a small number of relatively reliable clients, for example several organizations collaborating to train a single model.
First, the argument setup:
def add_args(parser):
parser.add_argument(
"--model",
type=str,
default="mobilenet",
metavar="N",
help="neural network used in training",)
parser.add_argument(
"--data_parallel",
type=int,
default=0,
help="if distributed training"
)
parser.add_argument(
"--dataset",
type=str,
default="cifar10",
metavar="N",
help="dataset used for training",
)
parser.add_argument(
"--data_dir",
type=str,
default="./../../../data/cifar10",
help="data directory"
)
parser.add_argument(
"--partition_method",
type=str,
default="hetero",
metavar="N",
help="how to partition the dataset on local workers",
)
parser.add_argument(
"--partition_alpha",
type=float,
default=0.5,
metavar="PA",
help="partition alpha (default: 0.5)",
)
parser.add_argument(
"--client_num_in_total",
type=int,
default=1000,
metavar="NN",
help="number of workers in a distributed cluster",
)
parser.add_argument(
"--client_num_per_round",
type=int,
default=4,
metavar="NN",
help="number of workers",
)
parser.add_argument(
"--batch_size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--client_optimizer",
type=str,
default="adam",
help="SGD with momentum; adam",
)
parser.add_argument(
"--lr",
type=float,
default=0.001,
metavar="LR",
help="learning rate (default: 0.001)",
)
parser.add_argument(
"--wd",
help="weight decay parameter;",
type=float,
default=0.0001,
)
parser.add_argument(
"--epochs",
type=int,
default=5,
metavar="EP",
help="how many epochs will be trained locally",
)
parser.add_argument(
"--comm_round",
type=int,
default=10,
help="how many round of communications we shoud use",
)
parser.add_argument(
"--is_mobile",
type=int,
default=0,
help="whether the program is running on the FedML-Mobile server side",
)
parser.add_argument(
"--frequency_of_train_acc_report",
type=int,
default=10,
help="the frequency of training accuracy report",
)
parser.add_argument(
"--frequency_of_test_acc_report",
type=int,
default=1,
help="the frequency of test accuracy report",
)
parser.add_argument(
"--gpu_sever_num",
type=int,
default=1,
help="gpu_server_num"
)
parser.add_argument(
"--gpu_num_per_sever",
type=int,
default=4,
help="gpu_num_per_server"
)
parser.add_argument(
"--ci",
type=int,
default=0,
help="CI",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
help="gpu",
)
parser.add_argument(
"--gpu_util",
type=str,
default="0",
help="gpu utils",
)
parser.add_argument(
"--local_rank",
type=int,
default=0,
help="given by torch.distributed.launch"
)
args=parser.parse_args()
return args
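As a quick usage sketch (the script name main.py and the example flags in the comment are only illustrative, not part of FedML):
import argparse

# add_args registers every FedAvg option on the parser and parses the command line.
parser = argparse.ArgumentParser(description="centralized FedAvg baseline")
args = add_args(parser)

# Any default can be overridden when launching, e.g.
#   python main.py --model resnet56 --dataset cifar10 --lr 0.01 --epochs 20
print(args.model, args.dataset, args.lr)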
Next, given the parameter configuration, load_data builds the dataloaders for the chosen dataset:
def load_data(args,dataset_name):
if dataset_name=="mnist":
logging.info("load_data.dataset_name=%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_mnist(args.batch_size)
args.client_num_in_total=client_num
elif dataset_name=="femnist":
logging.info("load_data.dataset_name==%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
class_num,
)=load_partition_data_federated_emnist(args.dataset,args.data_dir)
args.client_num_in_total=client_num
elif dataset_name=="shakespeare":
logging.info("load_data.dataset_name=%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
        )=load_partition_data_shakespeare(args.batch_size)
args.client_num_in_total=client_num
elif dataset_name=="fed_shakespeare":
logging.info("load_data.dataset_name=%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_federated_shakespeare(args.dataset,args.batch_size)
args.client_num_in_total=client_num
elif dataset_name=="fed_cifar100":
logging.info("load_data.dataset_name=%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_federated_cifar100(args.dataset,args.batch_size)
args.client_num_in_total=client_num
elif dataset_name=="stackoverflow_lr":
logging.info("load_data.dataset_name=%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_federated_stackoverflow_lr(args.dataset,args.data_dir)
args.client_num_in_total=client_num
elif dataset_name=="stackoverflow_nwp":
logging.info("load_data.dataset_name=%s"%dataset_name)
(
client_num,
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_federated_stackoverflow_nwp(args.dataset,args.data_dir)
args.client_num_in_total=client_num
elif dataset_name in ["ILSVRC2012","ILSVRC2012_hdf5"]:
if args.data_parallel==1:
logging.info("load_data.dataset_name=%s"%dataset_name)
(
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=distributed_centralized_ImageNet_loader(
dataset=dataset_name,
data_dir=args.data_dir,
world_size=args.world_size,
rank=args.rank,
batch_size=args.batch_size,
)
else:
logging.info("load_data.dataswet_name=%s"%dataset_name)
(
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_ImageNet(
dataset=dataset_name,
                data_dir=args.data_dir,
partition_method=None,
partition_alpha=None,
client_num=args.client_num_in_total,
batch_size=args.batch_size,
)
elif dataset_name=="gld23k":
logging.info("load_data.dataset_name=%s"%dataset_name)
args.client_num_in_total=233
fed_train_map_file=os.path.join(
args.data_dir,"data_user_dict/gld23k_user_dict_train.csv"
)
fed_test_map_file=os.path.join(
args.data_dir,"data_user_dict/gld23k_user_dict_test.csv"
)
args.data_dir=os.path.join(args.data_dir,"images")
(
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_landmarks(
dataset=dataset_name,
data_dir=args.data_dir,
fed_train_map_file=fed_train_map_file,
fed_test_map_file=fed_test_map_file,
partition_method=None,
partition_alpha=None,
client_number=args.client_num_in_total,
batch_size=args.batch_size
)
elif dataset_name=="gld160k":
logging.info("load_data.data_name=%s"%dataset_name)
args.client_num_in_total=1262
fed_train_map_file=os.path.join(
args.data_dir,"data_user_dict/gld160k_user_dict_train.csv"
)
fed_test_map_file=os.path.join(
args.data_dir,"data_user_dict/gld160k_user_dict_test.csv"
)
args.data_dir=os.path.join(args.data_dir,"images")
(
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=load_partition_data_landmarks(
dataset=dataset_name,
data_dir=args.data_dir,
fed_train_map_file=fed_train_map_file,
fed_test_map_file=fed_test_map_file,
partition_method=None,
partition_alpha=None,
client_number=args.client_num_in_total,
batch_size=args.batch_size,
)
else:
if dataset_name=="cifar10":
data_loader=load_partition_data_cifar100
elif dataset_name=="cifar100":
data_loader=load_partition_data_cifar100
elif dataset_name=="cinic10":
data_loader=load_partition_data_cinic10
else:
data_loader=load_partition_data_cifar10
(
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
)=data_loader(
args.dataset,
args.data_dir,
            args.partition_method,
args.partition_alpha,
args.client_num_in_total,
args.batch_size,
)
dataset=[
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
]
return dataset
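To make the return value concrete, here is a rough sketch of how the eight-element list is normally consumed; the meanings in the comments are my reading of the names, not official documentation:
# load_data returns a plain list; position 7 is the number of classes.
dataset = load_data(args, args.dataset)
(
    train_data_num,             # total number of training samples
    test_data_num,              # total number of test samples
    train_data_global,          # one DataLoader over all training data
    test_data_global,           # one DataLoader over all test data
    train_data_local_num_dict,  # client index -> number of local training samples
    train_data_local_dict,      # client index -> local training DataLoader
    test_data_local_dict,       # client index -> local test DataLoader
    class_num,                  # number of label classes
) = dataset
print("clients:", args.client_num_in_total, "classes:", class_num)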
Next, model creation:
def create_model(args,model_name,output_dim):
logging.info(
"create_model.model_name=%s,output_dim=%s"%(model_name,output_dim)
)
model=None
if model_name=="lr" and args.dataset=="mnist":
logging.info("LogisticRegression+MNIST")
model=LogisticRegression(28*28,output_dim)
elif model_name== "cnn" and args.dataset=="femni":
logging.info("CNN+FederatedEMNIST")
model=CNN_DropOut(False)
elif model_name=="resnet18_gn" and args.dataset=="fed_cifar100":
logging.info("ResNet18_GN+Federated_CIFAR100")
model=resnet18()
elif model_name=="rnn" and args.dataset=='shakespeare':
logging.info("RNN+shakespeare")
model=RNN_OriginalFedAvg()
elif model_name=="rnn" and args.dataset=="fed_shakespeare":
logging.info("RNN+fed_shakespeare")
        model=RNN_OriginalFedAvg()
elif model_name=="lr" and args.dataset=="stackoverflow_lr":
logging.info("lr+stackoverflow_lr")
model=LogisticRegression(10004,output_dim)
elif model_name=="rnn" and args.dataset=="stackoverflow_nwp":
logging.info("CNN+stackoverflow_nwp")
model=RNN_StackOverFlow()
elif model_name=="resnet56":
model=resnet56(class_num=output_dim)
elif model_name=="mobilenet":
model=mobilenet(class_num=output_dim)
elif model_name=="mobilenet_v3":
model=MobileNetV3(model_mode="LARGE",num_classes=output_dim)
elif model_name=="efficientnet":
efficientnet_dict={
"efficientnet-b0": (1.0, 1.0, 224, 0.2),
"efficientnet-b1": (1.0, 1.1, 240, 0.2),
"efficientnet-b2": (1.1, 1.2, 260, 0.3),
"efficientnet-b3": (1.2, 1.4, 300, 0.3),
"efficientnet-b4": (1.4, 1.8, 380, 0.4),
"efficientnet-b5": (1.6, 2.2, 456, 0.4),
"efficientnet-b6": (1.8, 2.6, 528, 0.5),
"efficientnet-b7": (2.0, 3.1, 600, 0.5),
"efficientnet-b8": (2.2, 3.6, 672, 0.5),
"efficientnet-l2": (4.3, 5.3, 800, 0.5),
}
model=EfficientNet.from_name(
model_name="efficientnet-b0",num_classes=output_dim
)
return model
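A small sanity check after calling the factory; this is plain PyTorch, and it assumes class_num came from load_data above and that args.model / args.dataset hit one of the branches (otherwise create_model returns None):
# Build the selected model and report its size.
model = create_model(args, model_name=args.model, output_dim=class_num)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("model: %s, trainable parameters: %d" % (args.model, num_params))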
Next, the main function:
if __name__=="__main__":
parser=argparse.ArgumentParser()
args=add_args(parser)
args.world_size=len(args.gpu_util.split(","))
worker_number=1
process_id=0
if args.data_parallel==1:
torch.distributed.init_process_group(backend="nccl",init_method="env://")
args.rank = torch.distributed.get_rank()
gpu_util = args.gpu_util.split(",")
gpu_util = [int(item.strip()) for item in gpu_util]
# device = torch.device("cuda", local_rank)
torch.cuda.set_device(gpu_util[args.rank])
process_id = args.rank
else:
args.rank = 0
logging.info(args)
str_process_name = "Fedml (single):" + str(process_id)
setproctitle.setproctitle(str_process_name)
logging.basicConfig(
level=logging.INFO,
# logging.basicConfig(level=logging.DEBUG,
format=str(process_id)
+ " - %(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
datefmt="%a, %d %b %Y %H:%M:%S",
)
hostname = socket.gethostname()
logging.info(
"#############process ID = "
+ str(process_id)
+ ", host name = "
+ hostname
+ "########"
+ ", process ID = "
+ str(os.getpid())
+ ", process Name = "
+ str(psutil.Process(os.getpid()))
)
# initialize the wandb machine learning experimental tracking platform (https://www.wandb.com/).
if process_id == 0:
wandb.init(
# project="federated_nas",
project="fedml",
name="Fedml (central)"
+ str(args.partition_method)
+ "r"
+ str(args.comm_round)
+ "-e"
+ str(args.epochs)
+ "-lr"
+ str(args.lr),
config=args,
)
# Set the random seed. The np.random seed determines the dataset partition.
# The torch_manual_seed determines the initial weight.
# We fix these two, so that we can reproduce the result.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
logging.info("process_id = %d, size = %d" % (process_id, args.world_size))
# load data
dataset = load_data(args, args.dataset)
[
train_data_num,
test_data_num,
train_data_global,
test_data_global,
train_data_local_num_dict,
train_data_local_dict,
test_data_local_dict,
class_num,
] = dataset
# create model.
# Note if the model is DNN (e.g., ResNet), the training will be very slow.
# In this case, please use our FedML distributed version (./fedml_experiments/distributed_fedavg)
model = create_model(args, model_name=args.model, output_dim=dataset[7])
if args.data_parallel == 1:
device = torch.device("cuda:" + str(gpu_util[args.rank]))
model.to(device)
model = DistributedDataParallel(
model, device_ids=[gpu_util[args.rank]], output_device=gpu_util[args.rank]
)
else:
device = torch.device(
"cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
)
# start "federated averaging (FedAvg)"
single_trainer = CentralizedTrainer(dataset, model, device, args)
single_trainer.train()
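To actually run this, something like python main.py --model resnet56 --dataset cifar10 --data_dir ./../../../data/cifar10 --gpu 0 should work for a single process (main.py is just a placeholder for whatever the file is called in your checkout). When --data_parallel 1 is set, the script is meant to be launched through python -m torch.distributed.launch --nproc_per_node=4 main.py --data_parallel 1 --gpu_util 0,1,2,3, since torch.distributed.launch sets the environment variables that init_process_group(init_method="env://") expects and supplies the --local_rank argument defined above.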
1.mqtt_s3_fedavg_cifar10_resnet20_example
First, dataset.py:
import MNN
from torchvision.datasets import CIFAR10
F=MNN.expr
class Cifar10Dataset(MNN.data.Dataset):
def __init__(self,training_dataset=True):
super(Cifar10Dataset,self).__init__()
self.is_training_dataset=training_dataset
trainset=CIFAR10(root="./data", train=True, download=True)
testset=CIFAR10(root="./data", train=False, download=True)
if self.is_training_dataset:
self.data=trainset.data.transpose(0,3,1,2)/255.0
self.labels=trainset.targets
else:
self.data=testset.data.transpose(0,3,1,2)/255.0
self.labels=testset.targets
def __getitem__(self, index):
dv=F.const(
self.data[index].flatten().tolist(),[3,32,32],F.data_format.NCHW
)
dl=F.const(
            [self.labels[index]],[],F.data_format.NCHW,F.dtype.uint8
)
return [dv],[dl]
def __len__(self):
if self.is_training_dataset:
return 50000
else:
return 10000
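Before wiring the dataset into the server, it helps to poke at one sample directly; this is only a sanity-check sketch and assumes MNN expression variables expose a shape attribute:
# Index the dataset directly; __getitem__ returns ([dv], [dl]).
train_set = Cifar10Dataset(training_dataset=True)
print(len(train_set))          # 50000 training images
inputs, targets = train_set[0]
print(inputs[0].shape)         # expected [3, 32, 32]
print(targets[0])              # the label as a scalar MNN variable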
Then torch_server.py, which feels very much like ordinary PyTorch code:
import MNN
import fedml
from fedml.cross_device import ServerMNN
from my_dataset import Cifar10Dataset
if __name__=="__main__":
args=fedml.init()
device=fedml.device.get_device(args)
train_dataset=Cifar10Dataset(True)
test_dataset=Cifar10Dataset(False)
    train_loader=MNN.data.DataLoader(train_dataset,batch_size=64,shuffle=True)
test_loader=MNN.data.Dataloader(
test_dataset,batch_size=args.batch_size,shuffle=False
)
class_num=10
model=fedml.model.create(args,output_dim=class_num)
    server=ServerMNN(
        args,device,test_loader,None
    )
    server.run()
The parameters are controlled through a YAML config:
common_args:
training_type:"cross_device"
using_mlops:false
random_seed:0
config_version:release
environment_args:
  bootstrap: config/bootstrap.sh
data_args:
dataset: "cifar"
data_cache_dir: ~/fedml_data
partition_method: "hetero"
partition_alpha: 0.5
train_size: 10000
test_size: 5000
model_args:
model: "resnet20"
deeplearning_backend: "mnn"
model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
global_model_file_path: "./model_file_cache/global_model.mnn"
train_args:
federated_optimizer: "FedAvg"
client_id_list: "[138]"
client_num_in_total: 1
client_num_per_round: 1
comm_round: 3
epochs: 1
batch_size: 100
client_optimizer: sgd
learning_rate: 0.03
weight_decay: 0.001
validation_args:
frequency_of_the_test: 5
device_args:
worker_num: 1 # this only reflects on the client number, not including the server
using_gpu: false
gpu_mapping_file: config/gpu_mapping.yaml
gpu_mapping_key: mapping_default
comm_args:
backend: "MQTT_S3_MNN"
mqtt_config_path: config/mqtt_config.yaml
s3_config_path: config/s3_config.yaml
tracking_args:
log_file_dir: ./log
enable_wandb: false
wandb_project: fedml
run_name: fedml_torch_fedavg_cifar_lr
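With this saved as fedml_config.yaml, the server script above is normally started with something like python torch_server.py --cf fedml_config.yaml; --cf is the flag fedml.init() reads the config path from, and the file name here is just the convention used in the FedML examples.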
Next, the dataset.py of the MNIST example:
import MNN
from torchvision.datasets import MNIST
F=MNN.expr
class MnistDataset(MNN.data.Dataset):
def __init__(self,training_dataset=True):
super(MnistDataset,self).__init__()
self.is_training_dataset=training_dataset
trainset=MNIST(root="./data", train=True, download=True)
testset = MNIST(root="./data", train=False, download=True)
if self.is_training_dataset:
self.data=trainset.data/255.0
self.labels=trainset.targets
else:
self.data=testset.data/255.0
self.labels=testset.targets
def __getitem__(self, index):
dv=F.const(
self.data[index].flatten().tolist(),
[1,28,28],
F.data_format.NCHW
)
        dl=F.const([self.labels[index]],[],F.data_format.NCHW,F.dtype.uint8)
        return [dv],[dl]
def __len__(self):
if self.is_training_dataset:
return 60000
else:
return 10000
Then the matching server script:
import MNN
import fedml
from fedml.cross_device import ServerMNN
from my_dataset import MnistDataset
if __name__=="__main__":
args=fedml.init()
device=fedml.device.get_device(args)
train_dataset=MnistDataset(True)
test_dataset=MnistDataset(False)
train_dataloader=MNN.data.DataLoader(train_dataset,batch_size=64,shuffle=True)
    test_dataloader=MNN.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
class_num=10
model=fedml.model.create(args,output_dim=class_num)
    server=ServerMNN(
        args,device,test_dataloader,None
    )
    server.run()