SCAFFOLD:基于 FedAvg 的改进方法——代码复现(联邦学习)

当前有一项工作需要实现 SCAFFOLD 算法。该方法通过引入修正项(控制变量)c 来缓解客户端漂移(client drift)现象。
参考 GitHub 上的相关框架后,复现了该算法。
算法分为三个模块:
optimizer:重写的 SGD 优化器
clientscaffold:客户端操作
serverscaffold:服务端操作

optimizer部分代码:

import torch
from torch.optim import Optimizer

class SCAFFOLDOptimizer(Optimizer):
    """SGD variant used by SCAFFOLD clients.

    Each step applies the control-variate-corrected update
    ``p <- p - lr * (grad + weight_decay * p + c - c_i)``, where ``c`` is the
    server control variate and ``c_i`` is the client's own control variate.

    Args:
        params: Iterable of parameters (or param-group dicts) to optimize.
        lr: Learning rate.
        weight_decay: L2 penalty coefficient; ``0`` disables it.
    """

    def __init__(self, params, lr, weight_decay):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SCAFFOLDOptimizer, self).__init__(params, defaults)

    def step(self, server_controls, client_controls, closure=None):
        """Perform one control-variate-corrected SGD step.

        Args:
            server_controls: Sequence of tensors ``c``, one per parameter, in
                the order the parameters appear across all param groups.
            client_controls: Sequence of tensors ``c_i``, aligned the same way
                as ``server_controls``.
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure has to be evaluated; previously the callable itself
            # was returned instead of the loss.
            with torch.enable_grad():
                loss = closure()

        # A single iterator shared by all groups keeps the controls aligned
        # with the parameters even when there are several param groups
        # (restarting the zip per group reused the first controls each time).
        controls = iter(zip(server_controls, client_controls))
        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            for p, (c, ci) in zip(group['params'], controls):
                if p.grad is None:
                    continue
                # SCAFFOLD correction: use the local gradient corrected by (c - c_i).
                d_p = p.grad.data + c.data - ci.data
                if weight_decay != 0:
                    d_p = d_p + weight_decay * p.data
                p.data = p.data - d_p * lr
        return loss

serverscaffold:

from flcore.clients.clientscaffold import clientScaffold
from flcore.servers.serverbase import Server
from utils.data_utils import read_client_data
from threading import Thread
import torch
import random

class Scaffold(Server):
    """Server side of SCAFFOLD.

    Holds the global model and the server control variate ``c``. In each
    round it broadcasts both to every client, lets the selected clients
    train, and folds the returned model/control deltas back in.
    """

    def __init__(self, device, dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps, join_clients,
                 num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate, time_select, goal, time_threthold):
        super().__init__(dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps, join_clients,
                         num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate, time_select, goal,
                         time_threthold)
        # select slow clients
        self.set_slow_clients()

        self.global_model = model
        for i, train_slow, send_slow in zip(range(self.num_clients), self.train_slow_clients, self.send_slow_clients):
            train, test = read_client_data(dataset, i)
            client = clientScaffold(device, i, train_slow, send_slow, train, test, model, batch_size, learning_rate, local_steps)
            self.clients.append(client)

        print(f"\nJoin clients / total clients: {self.join_clients} / {self.num_clients}")

        # Server control variate c: one zero tensor per trainable parameter.
        self.server_controls = [torch.zeros_like(p.data) for p in model.parameters() if p.requires_grad]

    def train(self):
        """Run ``global_rounds + 1`` communication rounds, then report and save.

        Each round broadcasts the global model and ``c``, evaluates every
        ``eval_gap`` rounds, trains the selected clients, and aggregates.
        """
        for i in range(self.global_rounds + 1):
            self.send_parameters()  # broadcast the global model and control variate c

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate global model")
                self.evaluate()

            self.selected_clients = self.select_clients()
            for client in self.selected_clients:
                client.train()

            self.aggregate_parameters()

        print("\nBest global results.")
        self.print_(max(self.rs_test_acc), max(
            self.rs_train_acc), min(self.rs_train_loss))

        self.save_results()
        self.save_global_model()

    def send_parameters(self):
        """Send the global model and the server control variate to all clients."""
        assert (len(self.clients) > 0)
        for client in self.clients:
            client.set_parameters(self.global_model)
            for control, new_control in zip(client.server_controls, self.server_controls):
                # Copy instead of aliasing so clients never share storage with the server's c.
                control.data = new_control.data.clone()

    def aggregate_parameters(self):
        """Aggregate model and control deltas from the clients that did not drop out.

        A random subset of the selected clients is treated as active. Its size
        is ``(1 - client_drop_rate) * join_clients``, capped at the number of
        clients actually selected. Each active client is weighted by its share
        of the active training samples.
        """
        assert (len(self.selected_clients) > 0)

        # Cap the sample size: random.sample raises ValueError when k > population.
        num_active = min(len(self.selected_clients), int((1 - self.client_drop_rate) * self.join_clients))
        active_clients = random.sample(self.selected_clients, num_active)

        active_train_samples = sum(client.train_samples for client in active_clients)
        if active_train_samples == 0:
            # Nothing to aggregate: every client dropped out or had no data.
            self.uploaded_weights = []
            return

        self.uploaded_weights = [client.train_samples / active_train_samples for client in active_clients]

        # SCAFFOLD (Algorithm 1): c <- c + (|S| / N) * mean(delta c_i). Only the
        # participating fraction of the N clients contributed new controls, so
        # the (normalized) control update is scaled by that fraction.
        control_scale = len(active_clients) / self.num_clients
        for user, w in zip(active_clients, self.uploaded_weights):
            self.add_parameters(user, active_train_samples, w, control_scale)

    def add_parameters(self, user, total_samples, w, control_scale=1.0):
        """Fold one client's model and control deltas into the global state.

        Args:
            user: Client whose ``delta_model`` / ``delta_controls`` are applied.
            total_samples: Total training samples of the active clients. It is
                kept for interface compatibility; ``w`` already encodes this
                client's share.
            w: Aggregation weight of this client (its sample share).
            control_scale: Extra factor applied to the control update. The
                caller passes ``|S| / N`` so that ``c`` tracks the mean of all
                client controls. The default of 1.0 reproduces the old behavior.
        """
        # Clients are non-IID, so sample-proportional weights replace the uniform 1/|S| average.
        for param, control, del_control, del_model in zip(self.global_model.parameters(), self.server_controls,
                                                          user.delta_controls, user.delta_model):
            param.data = param.data + del_model.data * w
            control.data = control.data + del_control.data * w * control_scale

clientscaffold:

import torch
import torch.nn as nn
from flcore.clients.clientbase import Client
import numpy as np
import time
import copy
from flcore.optimizers.fedoptimizer import *
import math
class clientScaffold(Client):
    """SCAFFOLD client.

    It runs local SGD corrected by the control variates ``(c - c_i)``, then
    updates its own control variate ``c_i`` using option II of the SCAFFOLD
    paper.
    """

    def __init__(self, device, numeric_id, train_slow, send_slow, train_data, test_data, model, batch_size, learning_rate,
                 local_steps):
        super().__init__(device, numeric_id, train_slow, send_slow, train_data, test_data, model, batch_size, learning_rate,
                         local_steps)

        self.loss = nn.CrossEntropyLoss()

        L = 0  # weight decay passed to the optimizer; 0 disables it
        # Custom optimizer that applies the (c - c_i) correction on every step.
        self.optimizer = SCAFFOLDOptimizer(self.model.parameters(), lr=self.learning_rate, weight_decay=L)

        # Client control variate c_i.
        self.controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        # Local copy of the server control variate c (written by the server before training).
        self.server_controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        # c_i^+ - c_i from the latest round, read by the server during aggregation.
        self.delta_controls = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        # y_i - x from the latest round (local model minus received global model).
        self.delta_model = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]
        # Snapshot of the global model x received in set_parameters.
        self.server_model = [torch.zeros_like(p.data) for p in self.model.parameters() if p.requires_grad]

        self.local_model = copy.deepcopy(list(self.model.parameters()))

    def set_grads(self, new_grads):
        """Overwrite the model parameter data with ``new_grads``.

        NOTE(review): despite its name, this assigns to ``.data`` of the
        parameters, not to ``.grad``. The ``nn.Parameter`` branch also zips
        over the rows of a single tensor; confirm the intended input type.
        """
        if isinstance(new_grads, nn.Parameter):
            for model_grad, new_grad in zip(self.model.parameters(), new_grads):
                model_grad.data = new_grad.data
        elif isinstance(new_grads, list):
            for idx, model_grad in enumerate(self.model.parameters()):
                model_grad.data = new_grads[idx]

    def train(self):
        """Run local corrected SGD, then update ``c_i`` (SCAFFOLD option II).

        After the call, ``delta_model`` holds ``y_i - x`` and
        ``delta_controls`` holds ``c_i^+ - c_i``. The server reads both when
        aggregating.
        """
        start_time = time.time()

        self.model.train()

        max_local_steps = self.local_steps
        if self.train_slow:
            # np.random.randint requires high > low; without the clamp,
            # local_steps < 4 raised ValueError here.
            max_local_steps = np.random.randint(1, max(2, max_local_steps // 2))

        for _ in range(max_local_steps):
            if self.train_slow:
                time.sleep(0.1 * np.abs(np.random.rand()))
            x, y = self.get_next_train_batch()
            self.optimizer.zero_grad()
            output = self.model(x)
            loss = self.loss(output, y)
            loss.backward()
            self.optimizer.step(self.server_controls, self.controls)

        # Model difference with respect to the received global model: y_i - x.
        for local, server, delta in zip(self.model.parameters(), self.server_model, self.delta_model):
            delta.data = local.data.clone() - server.data.clone()

        # Option II: c_i^+ = c_i - c + (x - y_i) / (K * lr), where K is the
        # number of local SGD steps actually taken in the loop above (one batch
        # per step). Previously K was the number of batches per epoch, which
        # does not match the steps performed.
        inv_k_lr = 1 / (max_local_steps * self.learning_rate)
        new_controls = [
            control.data - server_control.data - delta.data * inv_k_lr
            for server_control, control, delta in zip(self.server_controls, self.controls, self.delta_model)
        ]

        # Record the control change for the server, then commit c_i^+.
        for control, new_control, delta in zip(self.controls, new_controls, self.delta_controls):
            delta.data = new_control - control.data
            control.data = new_control

        self.train_time_cost['num_rounds'] += 1
        self.train_time_cost['total_cost'] += time.time() - start_time

    def set_parameters(self, server_model):
        """Load the global model into the local model and record it as both the local and the server snapshot."""
        for old_param, new_param, local_param, server_param in zip(self.model.parameters(), server_model.parameters(), self.local_model, self.server_model):
            old_param.data = new_param.data.clone()
            local_param.data = new_param.data.clone()
            server_param.data = new_param.data.clone()

你可能感兴趣的:(联邦学习,机器学习)