FedML 编程：FedAvg 单机（standalone）实现

客户端（Client）代码：fedml_api/standalone/fedavg/client.py

import logging
import torch
from torch import nn

class Client:
    """One federated-learning participant that owns a local data partition.

    The client trains a copy of the global model on its local training data
    and evaluates a model either on its own train/test data or on a loader
    supplied by the caller.
    """

    def __init__(self, local_training_data, local_test_data, local_sample_number, args, device):
        """Store the client's data, configuration and loss function.

        Args:
            local_training_data: iterable of ``(inputs, labels)`` batches used
                for local training. Presumably a ``torch.utils.data.DataLoader``:
                ``train`` reads ``len(...)`` and ``.dataset`` from it for logging.
            local_test_data: iterable of ``(inputs, labels)`` batches used for
                local evaluation.
            local_sample_number: number of local training samples; returned by
                ``get_sample_number`` and used by the server as the weight of
                this client's update.
            args: namespace providing at least ``lr`` and ``epochs``.
            device: ``torch.device`` on which training and evaluation run.
        """
        # Bug fix: the original assigned ``local_test_data`` here, so every
        # client silently trained on its test split.
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        self.local_sample_number = local_sample_number
        logging.info("self.local_sample_number = " + str(self.local_sample_number))
        self.args = args
        self.device = device
        self.criterion = nn.CrossEntropyLoss().to(device)

    def get_sample_number(self):
        """Return the number of local training samples held by this client."""
        return self.local_sample_number

    def train(self, net):
        """Train ``net`` on the local training data for ``args.epochs`` epochs.

        Args:
            net: model to train; it is modified in place (the caller passes a
                deep copy of the global model).

        Returns:
            tuple: ``(state_dict, avg_loss)``. ``state_dict`` holds the trained
            parameters after moving the model to CPU. ``avg_loss`` is the mean,
            over epochs that saw at least one batch, of the per-epoch average
            batch loss (``0.0`` if no batch was seen).
        """
        # The inputs are moved to self.device below, so the model must be too.
        net.to(self.device)
        net.train()
        # NOTE(review): weight decay is hard-coded; the ``--wd`` CLI option is
        # not consulted here.
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
                                     lr=self.args.lr, weight_decay=0.0001, amsgrad=True)
        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.local_training_data):
                images, labels = images.to(self.device), labels.to(self.device)
                net.zero_grad()
                logits = net(images)
                loss = self.criterion(logits, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
                logging.info('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(images), len(self.local_training_data.dataset), 100. * batch_idx / len(self.local_training_data), loss.item()))
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # Bug fix: the original never filled epoch_loss and divided by len(epoch),
        # which raises TypeError because epoch is an int.
        avg_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
        return net.cpu().state_dict(), avg_loss

    def _evaluate(self, model, data):
        """Return ``(num_correct, num_samples, summed_loss)`` of ``model`` on ``data``."""
        model.eval()
        model.to(self.device)
        test_loss = test_acc = test_total = 0
        with torch.no_grad():
            for x, target in data:
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model(x)
                loss = self.criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                test_acc += predicted.eq(target).sum().item()
                # The criterion returns the batch mean; scale it back to a sum
                # so callers can divide by the total sample count.
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)
        return test_acc, test_total, test_loss

    def local_test(self, model_global, b_use_test_dataset=False):
        """Evaluate ``model_global`` on this client's local data.

        Args:
            model_global: model to evaluate; it is switched to eval mode and
                moved to ``self.device``.
            b_use_test_dataset: evaluate on the local test data if True,
                otherwise on the local training data.

        Returns:
            tuple: ``(num_correct, num_samples, summed_loss)``.
        """
        test_data = self.local_test_data if b_use_test_dataset else self.local_training_data
        return self._evaluate(model_global, test_data)

    def global_test(self, model_global, global_test_data):
        """Evaluate ``model_global`` on a caller-supplied loader.

        Returns:
            tuple: ``(num_correct, num_samples, summed_loss)``.
        """
        return self._evaluate(model_global, global_test_data)

FedAvg 训练器代码：fedml_api/standalone/fedavg/fedavg_trainer.py

import copy
import logging

import torch
import wandb
from torch import nn

from fedml_api.standalone.fedavg.client import Client

class FedAvgTrainer(object):
    """Single-process FedAvg simulation over a fixed set of clients.

    In every communication round each client trains a copy of the global
    model on its local data. The server then replaces the global weights with
    the sample-count-weighted average of the returned client weights.
    """

    def __init__(self, dataset, model, device, args):
        """Unpack ``dataset``, keep the global model and create the clients.

        Args:
            dataset: 8-element sequence ``[train_data_num, test_data_num,
                train_data_global, test_data_global, data_local_num_dict,
                train_data_local_dict, test_data_local_dict, class_num]``.
            model: global ``torch.nn.Module`` shared by all clients.
            device: ``torch.device`` used for training and evaluation.
            args: parsed arguments. This class reads ``client_number`` and
                ``comm_round`` and passes ``args`` on to each ``Client``.
        """
        self.device = device
        self.args = args
        [train_data_num, test_data_num, train_data_global, test_data_global,
         data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num] = dataset
        self.class_num = class_num
        # Both global loaders are kept; global_test() evaluates on each.
        # (The original never stored train_data_global, so global_test failed.)
        self.train_global = train_data_global
        self.test_global = test_data_global
        self.train_data_num = train_data_num
        self.test_data_num = test_data_num
        self.model_global = model
        self.model_global.train()
        self.client_list = []
        self.setup_clients(data_local_num_dict, train_data_local_dict, test_data_local_dict)

    def setup_clients(self, data_local_num_dict, train_data_local_dict, test_data_local_dict):
        """Create one ``Client`` for each index in ``range(args.client_number)``.

        Each dict is indexed by client index: sample count, training loader and
        test loader respectively.
        """
        logging.info("############setup_clients (START)#############")
        for client_idx in range(self.args.client_number):
            c = Client(train_data_local_dict[client_idx], test_data_local_dict[client_idx],
                       data_local_num_dict[client_idx], self.args, self.device)
            self.client_list.append(c)
        logging.info("############setup_clients (END)#############")

    def train(self):
        """Run ``args.comm_round`` rounds of local training plus aggregation.

        After each round the averaged weights are loaded into the global model,
        which is then evaluated on every client's local train and test data.
        """
        for round_idx in range(self.args.comm_round):
            logging.info("Communication round : {}".format(round_idx))
            self.model_global.train()
            w_locals, loss_locals = [], []
            for client in self.client_list:
                # Each client trains its own deep copy, so the global model is
                # untouched until aggregation.
                w, loss = client.train(net=copy.deepcopy(self.model_global))
                w_locals.append((client.get_sample_number(), copy.deepcopy(w)))
                loss_locals.append(loss)
            w_glob = self.aggregate(w_locals)
            self.model_global.load_state_dict(w_glob)
            loss_avg = sum(loss_locals) / len(loss_locals)
            logging.info('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
            self.local_test(self.model_global, round_idx)

    def aggregate(self, w_locals):
        """Return the sample-count-weighted average of the clients' state dicts.

        Args:
            w_locals: non-empty list of ``(sample_number, state_dict)`` pairs
                that all share the same keys.

        Returns:
            A shallow copy of the first state dict with every entry replaced by
            the weighted average. The input dicts are not modified.
        """
        logging.info("################aggregate: %d" % len(w_locals))
        # Normalise by the samples actually contributed this round, so the
        # weights always sum to 1.
        total_samples = sum(num for num, _ in w_locals)
        _, first_params = w_locals[0]
        # copy.copy keeps the state-dict type (and its metadata) without
        # touching the client's tensors.
        averaged_params = copy.copy(first_params)
        # Bug fix: the original returned inside this loop, so only the first
        # parameter was ever averaged.
        for k in first_params.keys():
            acc = None
            for local_sample_number, local_model_params in w_locals:
                contribution = local_model_params[k] * (local_sample_number / total_samples)
                acc = contribution if acc is None else acc + contribution
            averaged_params[k] = acc
        return averaged_params

    def local_test(self, model_global, round_idx):
        """Evaluate ``model_global`` on every client's local train and test data."""
        self.local_test_on_training_data(model_global, round_idx)
        self.local_test_on_test_data(model_global, round_idx)

    def _evaluate_on_clients(self, model_global, b_use_test_dataset):
        """Return ``(accuracy, mean_loss, per_client_sample_counts)`` over all clients."""
        tot_corrects, num_samples, losses = [], [], []
        for c in self.client_list:
            tot_correct, num_sample, loss = c.local_test(model_global, b_use_test_dataset)
            tot_corrects.append(tot_correct)
            num_samples.append(num_sample)
            losses.append(loss)
        total = sum(num_samples)
        return sum(tot_corrects) / total, sum(losses) / total, num_samples

    def local_test_on_training_data(self, model_global, round_idx):
        """Log accuracy and loss of ``model_global`` on all clients' training data."""
        # Bug fix: the original appended to the int ``tot_correct`` instead of
        # the ``tot_corrects`` list.
        train_acc, train_loss, num_samples = self._evaluate_on_clients(model_global, False)
        wandb.log({"Train/AccTop1": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})

        stats = {'training_acc': train_acc, 'training_loss': train_loss, 'num_samples': num_samples}
        logging.info(stats)

    def local_test_on_test_data(self, model_global, round_idx):
        """Log accuracy and loss of ``model_global`` on all clients' test data."""
        test_acc, test_loss, num_samples = self._evaluate_on_clients(model_global, True)
        wandb.log({"Test/AccTop1": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})
        stats = {'test_acc': test_acc, 'test_loss': test_loss, 'num_samples': num_samples}
        logging.info(stats)

    def global_test(self):
        """Log the global model's accuracy on the global train and test loaders."""
        logging.info("################global_test")
        # Bug fix: the original passed ``self.device`` as an extra argument and
        # then as the dataset; the loaders themselves must be passed.
        acc_train, num_sample, loss_train = self.test_using_global_dataset(self.model_global,
                                                                           self.train_global)
        acc_train = acc_train / num_sample

        acc_test, num_sample, loss_test = self.test_using_global_dataset(self.model_global,
                                                                         self.test_global)
        acc_test = acc_test / num_sample

        logging.info("Global Training Accuracy: {:.2f}".format(acc_train))
        logging.info("Global Testing Accuracy: {:.2f}".format(acc_test))
        wandb.log({"Global Training Accuracy": acc_train})
        wandb.log({"Global Testing Accuracy": acc_test})

    def test_using_global_dataset(self, model_global, global_test_device):
        """Evaluate ``model_global`` on the loader ``global_test_device``.

        Despite its name, ``global_test_device`` is the iterable of
        ``(inputs, labels)`` batches to evaluate on; the name is kept for
        backward compatibility.

        Returns:
            tuple: ``(num_correct, num_samples, summed_loss)``.
        """
        model_global.eval()
        model_global.to(self.device)
        test_loss = test_acc = test_total = 0
        criterion = nn.CrossEntropyLoss().to(self.device)
        with torch.no_grad():
            for x, target in global_test_device:
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model_global(x)
                loss = criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                test_acc += predicted.eq(target).sum().item()
                # Mean batch loss scaled back to a per-sample sum.
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)
        return test_acc, test_total, test_loss
    

主函数（入口脚本）：

import argparse
import logging
import os
import sys

import numpy as np
import torch
import wandb

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_cifar100
from fedml_api.data_preprocessing.cinic10.data_loader import load_partition_data_cinic10
from fedml_api.model.deep_neural_networks.mobilenet import mobilenet
from fedml_api.model.deep_neural_networks.resnet import resnet56
from fedml_api.standalone.fedavg.fedavg_trainer import FedAvgTrainer

def add_args(parser, argv=None):
    """Register the FedAvg-standalone command-line options and parse them.

    Args:
        parser: an ``argparse.ArgumentParser`` to add the options to.
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser.add_argument('--model', type=str, default='resnet56', metavar='N',
                        help='neural network used in training')
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    # Fixed: help text previously claimed a default of 64.
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')

    # NOTE(review): the Client optimizer in this post hard-codes
    # weight_decay=0.0001 and does not read this option.
    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=0.001)

    parser.add_argument('--epochs', type=int, default=5, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')

    parser.add_argument('--client_number', type=int, default=4, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--comm_round', type=int, default=10,
                        help='how many rounds of communication we should use')

    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu')
    args = parser.parse_args(argv)
    return args

if __name__ == "__main__":
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    args = add_args(argparse.ArgumentParser(description='FedAvg-standalone'))

    # Fail fast on an unsupported model. Previously ``model`` stayed None and
    # FedAvgTrainer crashed later with an opaque AttributeError on None.train().
    supported_models = ("resnet56", "mobilenet")
    if args.model not in supported_models:
        raise ValueError("Unsupported model {!r}; expected one of {}".format(
            args.model, ", ".join(supported_models)))

    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    logger.info(device)
    wandb.init(
        project="fedml",
        name="FedAVG-r" + str(args.comm_round) + "-e" + str(args.epochs) + "-lr" + str(args.lr),
        config=args
    )

    # Fixed seeds for reproducible data partitioning and model initialisation.
    np.random.seed(0)
    torch.manual_seed(10)

    # load data: unknown dataset names still fall back to CIFAR-10, but now
    # the fallback is logged instead of happening silently.
    data_loaders = {
        "cifar10": load_partition_data_cifar10,
        "cifar100": load_partition_data_cifar100,
        "cinic10": load_partition_data_cinic10,
    }
    data_loader = data_loaders.get(args.dataset)
    if data_loader is None:
        logger.warning("Unknown dataset %r, falling back to cifar10 loader", args.dataset)
        data_loader = load_partition_data_cifar10
    train_data_num, test_data_num, train_data_global, test_data_global, \
        data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                args.partition_alpha, args.client_number, args.batch_size)

    dataset = [train_data_num, test_data_num, train_data_global, test_data_global,
               data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]
    model = None
    if args.model == "resnet56":
        model = resnet56(class_num)
    elif args.model == "mobilenet":
        model = mobilenet(class_num=class_num)

    trainer = FedAvgTrainer(dataset, model, device, args)
    trainer.train()
    trainer.global_test()

你可能感兴趣的：python、深度学习、人工智能