Federated Continual Learning with Weighted Inter-client Transfer——论文笔记

一. 简介

持续学习是一种序列化任务学习方式使得机器能够像人类一样不断去学习新知识而避免灾难性遗忘。然而,即便这样,这些模型根本上还是存在缺陷,因为每一个模型只能从直接经验去学习知识(意思是,只能对于自己的数据集去进行学习)。但是人类可以从书籍、视频等方式去获取他人的经验。那么不同机器也也可以存在这样一个方式去进行信息交换和学习,然而这样又到来隐私和通信的问题,而处理这一问题的方法就是联邦学习。
联邦学习是通过交换参数而不是原始数据本身来解决上述问题。然而简简单单的结合将会带来新的挑战。**首先是,持续的联邦学习不仅会带来灾难性遗忘,还会带来来自其他客户潜在的干扰。**因此我们需要有选择地利用来自其他客户的知识,以最小化客户间的干扰,最大化进行客户间的知识转移。第二个问题是联邦学习之间进行通信交换知识时,可能会造成通信成本过大
应对上述这两个问题,作者提出一个新的联邦持续学习框架——联邦加权客户间的传输(FedWeIT)。这个模型将局部模型参数分解为稠密基参数和稀疏任务自适应参数,以便进行更高效地通信。

二. 联邦持续学习与FedWeIT

2.1 问题定义

在标准的持续学习中,模型需要迭代的学习一系列任务 { T 1 , T 2 , T 3 . . . T T } \{T_1,T_2,T_3...T_T\} { T1,T2,T3...TT},每一个 T t T_t Tt代表一个特定的任务,由N个 { X , Y } \{X,Y\} { X,Y}组成。假设之前任务在训练之后不会出现,因为我们的目标是:

min ⁡ θ ( t ) L ( θ ( t ) ; θ ( t − 1 ) , T ( t ) ) \min_{\theta^{(t)}}L(\theta^{(t)};\theta^{(t-1)},T^{(t)}) θ(t)minL(θ(t);θ(t1),T(t))

接下来我们将拓展到联邦学习环境下。假设我们有C个客户端 c ∈ { c 1 , . . . , c C } c\in\{c_1,...,c_C\} c{ c1,...,cC},在私有可访问序列上训练任务 { T 1 , T 2 , T 3 . . . T T } \{T_1,T_2,T_3...T_T\} { T1,T2,T3...TT}。而我们的目标就是高校地训练这些客户端上的持续学习模型通过与全局服务器通信模型参数(FedAvg)

2.2 可传达的持续性学习

首先,我们需要服务器从客户端中选择一部分客户端进行训练,客户端训练后会将当前的参数传入服务器。服务器接收后将这些参数汇聚成一个参数 θ ( r ) \theta^{(r)} θ(r),这就是当前最常用的两种框架(FedAvg,FedProx)。然而简单的这样必定会导致客户端出现灾难性遗忘,这里有几种常见的方法去固化以前的参数,冻结重要的权重(EWC等)。
因此,我们只需要关注于联邦持续学习中出现的新挑战。在这个环境中,参数聚合时允许跨客户间的知识转移。因为一个任务在客户端i上训练到第q轮的时候可能会和另一个客户端j上训练到第p轮时存在知识的相关性。同时如此多的训练要求下,我们需要考虑到通信的问题。

2.3 FedWIT

由于多个客户端学到的所有任务的知识都存储在一组参数 θ g \theta_g θg中,然而如果为了进行知识转移,每一个客户端都应该只利用其他客户培训的相关任务的知识而忽略掉干扰学习的不相关的任务的知识。
针对这个,作者使用的是分解参数的方式,分解如下:

θ ( t ) = B c ( t ) ⊙ m c ( t ) + A c ( t ) + ∑ i ∈ C ∑ j < ∣ t ∣ α i , j ( t ) A i ( j ) \theta^{(t)}=B^{(t)}_c \odot m_c^{(t)}+A^{(t)}_c+\sum_{i\in C}\sum_{j<|t|}\alpha_{i,j}^{(t)}A_i^{(j)} θ(t)=Bc(t)mc(t)+Ac(t)+iCj<tαi,j(t)Ai(j)

其中B表示是在第c个客户端上共享的参数(也就是服务器聚合的参数)
,m则是保证客户端接受不了到的B参数不被不相关的因素影响。A为任务自适应参数,设置A是因为需要捕获第一项没有捕获到的关于此任务的知识。最后一部分描述了加权的客户端知识转移。包括来自所有客户端所有任务自适应参数(对应A),以及为了有效选择这些其他客户的间接经验,还有对应的注意力参数 α \alpha α
参数分解之后可以定义新的优化目标:

min ⁡ B c ( t ) , m c ( t ) , A c ( 1 : t ) , α c ( t ) L ( θ c ( t ) ; T c ( t ) ) + λ 1 Ω ( m c ( t ) , A c ( 1 : t ) ) + λ 2 ∑ i = 1 t − 1 ∣ ∣ Δ B c ( t ) ⊙ m c ( i ) + Δ A c ( i ) ∣ ∣ 2 2 \min_{B_c^{(t)},m_c^{(t)},A_c^{(1:t)},\alpha_c^{(t)}}L(\theta_c^{(t)};T_c^{(t)})+\lambda_1\Omega({m_c^{(t)}},A_c^{(1:t)})+\lambda_2\sum_{i=1}^{t-1}||\Delta B_c^{(t)} \odot m_c^{(i)}+\Delta A_c^{(i)}||^2_2 Bc(t),mc(t),Ac(1:t),αc(t)minL(θc(t);Tc(t))+λ1Ω(mc(t),Ac(1:t))+λ2i=1t1ΔBc(t)mc(i)+ΔAc(i)22

其中L为损失函数, Ω ( . ) \Omega(.) Ω(.)是所有任务自适应参数和mask参数的稀疏正则化项(作者使用l1-normal)。第二个正则化项用于追溯更新过去的任务自适应参数,通过反应基础参数的变化,帮助任务自适应参数保目标任务的原始解。在这一部分 Δ B c ( t ) \Delta B_c^{(t)} ΔBc(t)表示当前训练的任务与上一个任务基本参数之间的差异。 Δ A c ( i ) \Delta A_c^{(i)} ΔAc(i)是当前任务与上一个任务自适应参数之间的差异。 λ 1 和 λ 2 \lambda_1和\lambda_2 λ1λ2都是超参数。

2.4 算法简介

客户端: 在每一轮r中,每一个客户端 C c C_c Cc利用服务传入的参数更新自己的基本参数: B c ( n ) = θ G ( n ) B_c(n)=\theta_G(n) Bc(n)=θG(n)。之后利用上面的优化函数进行更新, B c ( t ) = B c ( t ) ⊙ m c ( t ) B_c^{(t)}=B_c^{(t)} \odot m_c^{(t)} Bc(t)=Bc(t)mc(t),和任务自适应参数 A c ( t ) A_c^{(t)} Ac(t)传到服务端。
服务端:服务端首先聚合来自各个客户端的参数: θ G = 1 C ∑ C B c ( t ) \theta_G = \frac{1}{C}\sum_C{B_c^{(t)}} θG=C1CBc(t),之后将这个新的参数传给每一个客户端。对于t-1任务自适应参数 { A i ( t − 1 ) } i = 1 C c \{A_i^{(t-1)}\}_{i=1}^{C_{\\ c}} { Ai(t1)}i=1Cc在训练任务t的时候对每一个客户端进行一次广播。
算法如图所示:
Federated Continual Learning with Weighted Inter-client Transfer——论文笔记_第1张图片

三. 代码配合详解

作者的代码地址点这里,作者是用tensorflow写的,之后有时间我会把他改写成pytorch版本供大家看,具体的代码细节大家可以看作者提供的github地址,这里只针对关键算法进行分析。

3.1 服务端

首先是对于服务端需要初始化全局参数 θ G \theta_G θG,以及初始化客户端(也就是创建几个客户端),之后针对客户端开始训练

def run(self):
    self.logger.print('server', 'started')
    self.start_time = time.time()
    self.init_global_weights()
    self.init_clients()
    self.train_clients()

在训练客户端函数中,针对训练任务数以及每一个任务交换权重数目进行训练,

def train_clients(self):
	#选择一部分客户端进行训练
    cids = np.arange(self.args.num_clients).tolist()
    num_selection = int(round(self.args.num_clients*self.args.frac_clients))
    #对应的需要训练(任务数*交换轮数)
    for curr_round in range(self.args.num_rounds*self.args.num_tasks):
        self.updates = []
        self.curr_round = curr_round+1
        self.is_last_round = self.curr_round%self.args.num_rounds==0
        if self.is_last_round:
            self.client_adapts = []
        selected_ids = random.sample(cids, num_selection) # pick clients
        self.logger.print('server', 'round:{} train clients (selected_ids: {})'.format(curr_round, selected_ids))
        # 并行地训练对应的客户端
        for clients in self.parallel_clients:
            self.threads = []
            for gid, cid in enumerate(clients):
                client = self.clients[gid]
                selected = True if cid in selected_ids else False
                这里用threading.Tread创建多线程,进行训练,注意里面的参数:包括客户端、当前轮数、获得的参数以及之前任务集对应的knowledge base参数
                with tf.device('/device:GPU:{}'.format(gid)):
                    thrd = threading.Thread(target=self.invoke_client, args=(client, cid, curr_round, selected, self.get_weights(), self.get_adapts()))
                    self.threads.append(thrd)
                    thrd.start()
            # 等待全部客户端训练完
            for thrd in self.threads:
                thrd.join()
        # 聚合对应的参数,也就是mask ⊙ B
        aggr = self.train.aggregate(self.updates)
        self.set_weights(aggr)
    self.logger.print('server', 'done. ({}s)'.format(time.time()-self.start_time))
    sys.exit()

这是作者根据算法实现的整体框架,下面我们针对具体的部分来看一看。
首先是在每一个客户端开始训练前所需要的参数,包括服务器聚合的参数和knowledge base(以前任务的自适应参数)
θ G \theta_G θG直接从服务器获得即可,这里不多描述,而kb参数则是对每一个任务训练完之后进行合并

def get_adapts(self):
	# 只有当训练完一个任务的时候才添加
    if self.curr_round%self.args.num_rounds==1 and not self.curr_round==1:
        from_kb = []
        # 对每一个layer对应的kb参数进行添加
        for lid, shape in enumerate(self.nets.shapes):
            shape = np.concatenate([self.nets.shapes[lid],[int(round(self.args.num_clients*self.args.frac_clients))]], axis=0)
            from_kb_l = np.zeros(shape)
            # 每一个客户端进行添加对应的位置
            for cid, ca in enumerate(self.client_adapts):
                try:
                    if len(shape)==5:
                        from_kb_l[:,:,:,:,cid] = ca[lid]
                    else:
                        from_kb_l[:,:,cid] = ca[lid]
                except:
                    pdb.set_trace()           
            from_kb.append(from_kb_l)
        return from_kb
    else:
        return None

可以发现,如果当前训练的参数的shape为[x,y,z],总共有c个客户端参与训练,那么kb的shape应该为[x,y,z,c](也就是有c个对应的参数)
服务器除了需要聚合kb参数,还需要聚合全局参数(B),对应代码如下:

# updates参数包括了所有客户端的共享参数B和对应的mask
def aggregate(self, updates):
	#使用FedWIT框架
    if self.args.sparse_comm and self.args.model in ['fedweit']:
    	## 获取每一个共享参数B
        client_weights = [u[0][0] for u in updates]
        ## 获取对应的mask
        client_masks = [u[0][1] for u in updates]
        client_sizes = [u[1] for u in updates]
        ## 将要聚合的参数(初始化,为0)
        new_weights = [np.zeros_like(w) for w in client_weights[0]]
        ## 这个设置防止出现/0的操作
        epsi = 1e-15
        total_sizes = epsi
        ## 这里的mask经过转换,方便进行后面的工作,值没有变
        client_masks = tf.ragged.constant(client_masks, dtype=tf.float32)
        ## 将每一个客户端上各自对应的layer上的mask合并
        for _cs in client_masks:
            total_sizes += _cs
         ## 平均操作 这一步就相当于 average(B⊙mask),至于为什么写看之后的步骤
        for c_idx, c_weights in enumerate(client_weights): # by client
            for lidx, l_weights in enumerate(c_weights): # by layer
                ratio = 1/total_sizes[lidx]
                new_weights[lidx] += tf.math.multiply(l_weights, ratio).numpy()
    else:
        client_weights = [u[0] for u in updates]
        client_sizes = [u[1] for u in updates]
        new_weights = [np.zeros_like(w) for w in client_weights[0]]
        total_size = len(client_sizes)
        for c in range(len(client_weights)): # by client
            _client_weights = client_weights[c]
            for i in range(len(new_weights)): # by layer
                new_weights[i] +=  _client_weights[i] * float(1/total_size)
    return new_weights

我们完成聚合中发现有两个难理解的,一个是mask加起来的意义是什么,还有就是为什么求平均要这么算。除了这两个之外,我们应该还需要注意到的是,在论文中提到的B是一个高度稀疏的矩阵,而按照传统的loss的求,我们无法做到稀疏,这里就要看作者这部分的操作

def get_weights(self):
    if self.args.model in ['fedweit']:
        if self.args.sparse_comm:
            hard_threshold = []
            sw_pruned = []
            # 获得mask
            masks = self.nets.decomposed_variables['mask'][self.state['curr_task']]
            for lid, sw in enumerate(self.nets.decomposed_variables['shared']):
                mask = masks[lid]
                # 针对mask进行排序,目的是找出一个阈值=mask的长度*选取客户端的数量
                m_sorted = tf.sort(tf.keras.backend.flatten(tf.abs(mask)))
                thres = m_sorted[math.floor(len(m_sorted)*(self.args.client_sparsity))]
                # 将大于这个阈值的部分mask取1,其余取0
                m_bianary = tf.cast(tf.greater(tf.abs(mask), thres), tf.float32).numpy().tolist()
                hard_threshold.append(m_bianary)
                # 对应B*mask,进行稀疏
                sw_pruned.append(sw.numpy()*m_bianary)
            self.train.calculate_communication_costs(sw_pruned)
            return sw_pruned, hard_threshold
        else:
            return [sw.numpy() for sw in self.nets.decomposed_variables['shared']]
    else:
        return self.nets.get_body_weights()

在经过这几部,就发现解决了上面的一位,首先是B已经被mask稀疏,mask变为一堆0和1组成的参数,可以理解为重要的因子,所以之后求平均的时候,我们只需要累加那些对应mask为1的部分,再将我们的B去除以这个即可。

3.2 客户端

之后我们来看看客户端的训练中一些不同的地方
根据算法,我们客户端获取到对应的参数后,将客户端的sw(B参数)进行更新,之后按照在第二节中的所说的损失函数进行优化即可,loss如下:

def loss(self, y_true, y_pred):
    weight_decay, sparseness, approx_loss = 0, 0, 0
    # 第一部分直接损失
    loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    for lid in range(len(self.nets.shapes)):
    	# 三个对应的参数 B A mask
        sw = self.nets.get_variable(var_type='shared', lid=lid)
        aw = self.nets.get_variable(var_type='adaptive', lid=lid, tid=self.state['curr_task'])
        mask = self.nets.get_variable(var_type='mask', lid=lid, tid=self.state['curr_task'])
        # 第二部分参数,用来l2loss去求出A和mask的损失
        weight_decay += self.args.wd * tf.nn.l2_loss(aw)
        weight_decay += self.args.wd * tf.nn.l2_loss(mask)
        sparseness += self.args.lambda_l1 * tf.reduce_sum(tf.abs(aw))
        sparseness += self.args.lambda_mask * tf.reduce_sum(tf.abs(mask))
        # 如果训练到之后的任务,则需要加上以前任务的信息
        if self.state['curr_task'] == 0:
            weight_decay += self.args.wd * tf.nn.l2_loss(sw)
        else:
            for tid in range(self.state['curr_task']):
                prev_aw = self.nets.get_variable(var_type='adaptive', lid=lid, tid=tid)
                prev_mask = self.nets.get_variable(var_type='mask', lid=lid, tid=tid)
                g_prev_mask = self.nets.generate_mask(prev_mask)
                #################################################
                # 第三部分 B*mask + A
                restored = sw * g_prev_mask + prev_aw
                # 变成\delta B * mask + \delta A
                a_l2 = tf.nn.l2_loss(restored-self.state['prev_body_weights'][lid][tid])
                approx_loss += self.args.lambda_l2 * a_l2
                #################################################
                # 第二部分也需要加上以前任务的对应的\Omega 损失
                sparseness += self.args.lambda_l1 * tf.reduce_sum(tf.abs(prev_aw))
    
    loss += weight_decay + sparseness + approx_loss 
    return loss

loss定义完成后,更新B mask A还有所需要的bias这四个参数即可,这里我粘贴上代码运行中实际更新的参数(sw位B adaptive为A)
Federated Continual Learning with Weighted Inter-client Transfer——论文笔记_第2张图片

按照这个方式更新,之后就聚合B(3.1中部分),合并A为kb即可。
最后说明一下,文中提到的参数分解部分是如何改造的,作者这里是对卷积层和Dense层进行了继承,从而进行改造的。

class DecomposedConv(tf.keras.layers.Conv2D):
  """ Custom conv layer that decomposes parameters into shared and specific parameters.
  
  Base code is referenced from official tensorflow code (https://github.com/tensorflow/tensorflow/)

  Created by:
      Wonyong Jeong ([email protected])
  """
  def __init__(self, 
               filters,
               kernel_size,
               rank=2,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation=None,
               use_bias=False,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               trainable=True,
               name=None,
               lambda_l1=None,
               lambda_mask=None,
               shared=None,
               adaptive=None,
               from_kb=None,
               atten=None,
               mask=None,
               bias=None,
               **kwargs):
    # 初始化
    super(DecomposedConv, self).__init__(
               filters=filters,
               kernel_size=kernel_size,
               strides=strides,
               padding=padding,
               data_format=data_format,
               dilation_rate=dilation_rate,
               activation=activation,
               use_bias=use_bias,
               kernel_initializer=kernel_initializer,
               bias_initializer=bias_initializer,
               kernel_regularizer=kernel_regularizer,
               bias_regularizer=bias_regularizer,
               activity_regularizer=activity_regularizer,
               kernel_constraint=kernel_constraint,
               bias_constraint=bias_constraint,
               trainable=trainable,
               name=name, **kwargs)
    # 我们需要的参数 B A mask bias 以及alpha(对应atten)
    self.sw   = shared
    self.aw   = adaptive
    self.mask = mask
    self.bias = bias
    self.aw_kb = from_kb
    self.atten = atten
    self.lambda_l1   = lambda_l1
    self.lambda_mask = lambda_mask

  def l1_pruning(self, weights, hyp):
    hard_threshold = tf.cast(tf.greater(tf.abs(weights), hyp), tf.float32)
    return tf.multiply(weights, hard_threshold)
  # 前向传播计算
  def call(self, inputs):
    ################################################################################
    aw = self.aw if tf.keras.backend.learning_phase() else self.l1_pruning(self.aw, self.lambda_l1)
    mask = self.mask if tf.keras.backend.learning_phase() else self.l1_pruning(self.mask, self.lambda_mask)
    atten = self.atten
    aw_kbs = self.aw_kb
    ############################### Decomposed Kernel #################################
    # 这里就是将几个参数合并为 \theta 这里的kb是从服务端继承过来的
    self.my_theta = self.sw * mask + aw + tf.keras.backend.sum(aw_kbs * atten, axis=-1)
    ###################################################################################

    # if self._recreate_conv_op(inputs):
    # 根据我们的呢theta进行计算
    self._convolution_op = nn_ops.Convolution(
        inputs.get_shape(),
        filter_shape=self.my_theta.shape,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding="SAME",)
        # data_format=self._conv_op_data_format)

    # Apply causal padding to inputs for Conv1D.
    if self.padding == 'causal' and self.__class__.__name__ == 'Conv1D':
      inputs = array_ops.pad(inputs, self._compute_causal_padding())
   
    outputs = self._convolution_op(inputs, self.my_theta)

    if self.use_bias:
      if self.data_format == 'channels_first':
        if self.rank == 1:
          # nn.bias_add does not accept a 1D input tensor.
          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        else:
          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

Dense和卷积差不多,这里就不多赘述啦

四. 总结

到这里,基本上大家都能对应代码理解作者的思路了,作者还写了一个生成数据集的方法,之后我会贴上数据集的百度网盘。希望对大家能有帮助

你可能感兴趣的:(每日一次AI论文阅读)