Client-side code
import logging

import torch
from torch import nn


class Client:
    def __init__(self, local_training_data, local_test_data, local_sample_number, args, device):
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        self.local_sample_number = local_sample_number
        logging.info("self.local_sample_number = " + str(self.local_sample_number))
        self.args = args
        self.device = device
        self.criterion = nn.CrossEntropyLoss().to(device)
    def get_sample_number(self):
        return self.local_sample_number
    def train(self, net):
        net.train()
        # Local optimizer; only parameters with requires_grad=True are updated.
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=self.args.lr,
                                     weight_decay=0.0001, amsgrad=True)
        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.local_training_data):
                images, labels = images.to(self.device), labels.to(self.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
                logging.info('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(images), len(self.local_training_data.dataset),
                    100. * batch_idx / len(self.local_training_data), loss.item()))
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # Return the updated weights together with the mean loss over the local epochs.
        return net.cpu().state_dict(), sum(epoch_loss) / len(epoch_loss)
    def local_test(self, model_global, b_use_test_dataset=False):
        model_global.eval()
        model_global.to(self.device)
        test_loss = test_acc = test_total = 0
        if b_use_test_dataset:
            test_data = self.local_test_data
        else:
            test_data = self.local_training_data
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model_global(x)
                loss = self.criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                correct = predicted.eq(target).sum()
                test_acc += correct.item()
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)
        return test_acc, test_total, test_loss
    def global_test(self, model_global, global_test_data):
        model_global.eval()
        model_global.to(self.device)
        test_loss = test_acc = test_total = 0
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(global_test_data):
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model_global(x)
                loss = self.criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                correct = predicted.eq(target).sum()
                test_acc += correct.item()
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)
        return test_acc, test_total, test_loss
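
Client.train returns the updated state_dict plus the mean training loss over the local epochs. Below is a minimal smoke-test sketch, assuming the Client class above is available in the same module; the toy linear model, the random DataLoader, and the Namespace fields are my own illustration (the class itself only reads args.lr and args.epochs):

    import argparse

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy local data: 64 samples, 10 features, 3 classes (purely illustrative).
    x = torch.randn(64, 10)
    y = torch.randint(0, 3, (64,))
    loader = DataLoader(TensorDataset(x, y), batch_size=16)

    args = argparse.Namespace(lr=0.001, epochs=2)
    client = Client(loader, loader, local_sample_number=64, args=args, device="cpu")

    net = nn.Linear(10, 3)
    w, avg_loss = client.train(net)
    print(list(w.keys()), avg_loss)
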
FedAvg training code
import copy
import logging
import torch
import wandb
from torch import nn
from fedml_api.standalone.fedavg.client import Client
class FedAvgTrainer(object):
    def __init__(self, dataset, model, device, args):
        self.device = device
        self.args = args
        [train_data_num, test_data_num, train_data_global, test_data_global,
         data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num] = dataset
        self.class_num = class_num
        self.train_global = train_data_global
        self.test_global = test_data_global
        self.train_data_num = train_data_num
        self.test_data_num = test_data_num
        self.model_global = model
        self.model_global.train()
        self.client_list = []
        self.setup_clients(data_local_num_dict, train_data_local_dict, test_data_local_dict)
    def setup_clients(self, data_local_num_dict, train_data_local_dict, test_data_local_dict):
        logging.info("############setup_clients (START)#############")
        for client_idx in range(self.args.client_number):
            c = Client(train_data_local_dict[client_idx], test_data_local_dict[client_idx],
                       data_local_num_dict[client_idx], self.args, self.device)
            self.client_list.append(c)
        logging.info("############setup_clients (END)#############")
    def train(self):
        for round_idx in range(self.args.comm_round):
            logging.info("Communication round : {}".format(round_idx))
            self.model_global.train()
            w_locals, loss_locals = [], []
            # Each client trains a copy of the current global model on its local data.
            for idx, client in enumerate(self.client_list):
                w, loss = client.train(net=copy.deepcopy(self.model_global))
                w_locals.append((client.get_sample_number(), copy.deepcopy(w)))
                loss_locals.append(copy.deepcopy(loss))
            # Weighted averaging of the local models, then update the global model.
            w_glob = self.aggregate(w_locals)
            self.model_global.load_state_dict(w_glob)
            loss_avg = sum(loss_locals) / len(loss_locals)
            logging.info('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
            self.local_test(self.model_global, round_idx)
    def aggregate(self, w_locals):
        logging.info("################aggregate: %d" % len(w_locals))
        (num0, averaged_params) = w_locals[0]
        for k in averaged_params.keys():
            for i in range(0, len(w_locals)):
                local_sample_number, local_model_params = w_locals[i]
                # FedAvg: each client's parameters are weighted by its share of the training data.
                w = local_sample_number / self.train_data_num
                if i == 0:
                    averaged_params[k] = local_model_params[k] * w
                else:
                    averaged_params[k] += local_model_params[k] * w
        return averaged_params
    def local_test(self, model_global, round_idx):
        self.local_test_on_training_data(model_global, round_idx)
        self.local_test_on_test_data(model_global, round_idx)
    def local_test_on_training_data(self, model_global, round_idx):
        num_samples = []
        tot_corrects = []
        losses = []
        for c in self.client_list:
            tot_correct, num_sample, loss = c.local_test(model_global, False)
            tot_corrects.append(copy.deepcopy(tot_correct))
            num_samples.append(copy.deepcopy(num_sample))
            losses.append(copy.deepcopy(loss))
        train_acc = sum(tot_corrects) / sum(num_samples)
        train_loss = sum(losses) / sum(num_samples)
        wandb.log({"Train/AccTop1": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})
        stats = {'training_acc': train_acc, 'training_loss': train_loss, 'num_samples': num_samples}
        logging.info(stats)
    def local_test_on_test_data(self, model_global, round_idx):
        num_samples = []
        tot_corrects = []
        losses = []
        for c in self.client_list:
            tot_correct, num_sample, loss = c.local_test(model_global, True)
            tot_corrects.append(copy.deepcopy(tot_correct))
            num_samples.append(copy.deepcopy(num_sample))
            losses.append(copy.deepcopy(loss))
        test_acc = sum(tot_corrects) / sum(num_samples)
        test_loss = sum(losses) / sum(num_samples)
        wandb.log({"Test/AccTop1": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})
        stats = {'test_acc': test_acc, 'test_loss': test_loss, 'num_samples': num_samples}
        logging.info(stats)
    def global_test(self):
        logging.info("################global_test")
        acc_train, num_sample, loss_train = self.test_using_global_dataset(self.model_global,
                                                                           self.train_global)
        acc_train = acc_train / num_sample
        acc_test, num_sample, loss_test = self.test_using_global_dataset(self.model_global,
                                                                         self.test_global)
        acc_test = acc_test / num_sample
        logging.info("Global Training Accuracy: {:.2f}".format(acc_train))
        logging.info("Global Testing Accuracy: {:.2f}".format(acc_test))
        wandb.log({"Global Training Accuracy": acc_train})
        wandb.log({"Global Testing Accuracy": acc_test})
    def test_using_global_dataset(self, model_global, global_test_data):
        model_global.eval()
        model_global.to(self.device)
        test_loss = test_acc = test_total = 0
        criterion = nn.CrossEntropyLoss().to(self.device)
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(global_test_data):
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model_global(x)
                loss = criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                correct = predicted.eq(target).sum()
                test_acc += correct.item()
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)
        return test_acc, test_total, test_loss
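
The heart of the trainer is aggregate: with n_k local samples on client k and N = self.train_data_num total training samples, the new global weights are w_glob = sum_k (n_k / N) * w_k. A small self-contained sketch of that weighted averaging on two toy state_dicts (the helper name fedavg_aggregate and the toy sample counts are my own, not part of the original code):

    import torch
    from torch import nn


    def fedavg_aggregate(w_locals, total_sample_num):
        # w_locals: list of (sample_number, state_dict) tuples, as in FedAvgTrainer.aggregate.
        averaged = {}
        for k in w_locals[0][1].keys():
            averaged[k] = sum(n / total_sample_num * params[k] for n, params in w_locals)
        return averaged


    # Two toy "clients" with the same architecture but different weights.
    net_a, net_b = nn.Linear(4, 2), nn.Linear(4, 2)
    w_locals = [(30, net_a.state_dict()), (70, net_b.state_dict())]
    w_glob = fedavg_aggregate(w_locals, total_sample_num=100)

    # Sanity check: each entry equals 0.3 * w_a + 0.7 * w_b.
    expected = 0.3 * net_a.state_dict()['weight'] + 0.7 * net_b.state_dict()['weight']
    print(torch.allclose(w_glob['weight'], expected))  # True
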
Main function:
import argparse
import logging
import os
import sys
import numpy as np
import torch
import wandb
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_cifar100
from fedml_api.data_preprocessing.cinic10.data_loader import load_partition_data_cinic10
from fedml_api.model.deep_neural_networks.mobilenet import mobilenet
from fedml_api.model.deep_neural_networks.resnet import resnet56
from fedml_api.standalone.fedavg.fedavg_trainer import FedAvgTrainer
def add_args(parser):
    parser.add_argument('--model', type=str, default='resnet56', metavar='N',
                        help='neural network used in training')
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')
    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
                        help='data directory')
    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')
    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=5, metavar='EP',
                        help='how many epochs will be trained locally')
    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')
    parser.add_argument('--client_number', type=int, default=4, metavar='NN',
                        help='number of workers in a distributed cluster')
    parser.add_argument('--comm_round', type=int, default=10,
                        help='how many rounds of communication we should use')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu')
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    args = add_args(argparse.ArgumentParser(description='FedAvg-standalone'))
    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    logger.info(device)

    wandb.init(
        project="fedml",
        name="FedAVG-r" + str(args.comm_round) + "-e" + str(args.epochs) + "-lr" + str(args.lr),
        config=args
    )

    np.random.seed(0)
    torch.manual_seed(10)

    # load data
    data_loader = None
    if args.dataset == "cifar10":
        data_loader = load_partition_data_cifar10
    elif args.dataset == "cifar100":
        data_loader = load_partition_data_cifar100
    elif args.dataset == "cinic10":
        data_loader = load_partition_data_cinic10
    else:
        data_loader = load_partition_data_cifar10
    train_data_num, test_data_num, train_data_global, test_data_global, \
        data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                args.partition_alpha, args.client_number, args.batch_size)
    dataset = [train_data_num, test_data_num, train_data_global, test_data_global,
               data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]

    model = None
    if args.model == "resnet56":
        model = resnet56(class_num)
    elif args.model == "mobilenet":
        model = mobilenet(class_num=class_num)

    trainer = FedAvgTrainer(dataset, model, device, args)
    trainer.train()
    trainer.global_test()
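
With everything above in place, the script can be launched from the command line using the flags defined in add_args. An example invocation (the file name main_fedavg.py is my assumption; the values simply repeat the defaults):

    python main_fedavg.py --model resnet56 --dataset cifar10 --data_dir ./../../../data/cifar10 \
        --partition_method hetero --partition_alpha 0.5 --client_number 4 \
        --comm_round 10 --epochs 5 --batch-size 128 --lr 0.001 --gpu 0
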