Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析

论文地址点这里

一. 介绍

联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的。在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的数据。
在本文中,我们提出了一个基于无数据的知识蒸馏算法——FEDGEN。具体来说,FEDGEN学习一个仅从用户模型的预测规则派生的生成模型,给定一个目标标签,可以生成与用户预测集合一致的特征表示。该生成器随后广播给用户,在潜在空间上护送他们的模型训练与增强样本,这体现了来自其他同行用户的蒸馏知识。给定一个维数远远小于输入空间的潜在空间,FEDGEN学习到的生成器可以是轻量级的,为当前FL框架引入最小的开销。

二. 问题定义

我们使用 X ⊂ R p \mathcal{X} \subset \mathbb{R}^p XRp 表示为输入的实例空间 Z ⊂ R d \mathcal{Z} \subset \mathbb{R}^d ZRd 表示为潜在的特征空间,其中 d < p d

d<p, Y ⊂ R \mathcal{Y} \subset \mathbb{R} YR为输出空间。 T \mathcal{T} T 表示为一个具体的域(domain),其由数据样本 X \mathcal{X} X 组成的数据分布 D \mathcal{D} D 和一个真实标签的函数: c ∗ : X → Y c^*: \mathcal{X} \rightarrow \mathcal{Y} c:XY组成,也就是 T : = ⟨ D , c ∗ ⟩ \mathcal{T}:=\left\langle\mathcal{D}, c^*\right\rangle T:=D,c。注意,在本文中将任务和域当作一样对待。模型参数 θ : = [ θ f ; θ p ] \boldsymbol{\theta}:=\left[\boldsymbol{\theta}^f ; \boldsymbol{\theta}^p\right] θ:=[θf;θp]由两个部分组成:一个特征提取器 f : X → Z f: \mathcal{X} \rightarrow \mathcal{Z} f:XZ(对应参数为 θ f \boldsymbol{\theta}^f θf), 一个预测器 h : Z → △ Y h: \mathcal{Z} \rightarrow \triangle^{\mathcal{Y}} h:ZY (由参数 θ p \boldsymbol{\theta}^p θp组成),其中 Δ Y \Delta^{\mathcal{Y}} ΔY 表示单独的一个 Y \mathcal{Y} Y。给定一个凸的损失函数 l : △ Y × Y → R l: \triangle^{\mathcal{Y}} \times \mathcal{Y} \rightarrow \mathbb{R} l:Y×YR, 模型参数 θ \boldsymbol{\theta} θ 在任务 T \mathcal{T} T上的损失表示为 L T ( θ ) : = \mathcal{L}_{\mathcal{T}}(\boldsymbol{\theta}):= LT(θ):= E x ∼ D [ l ( h ( f ( x ; θ f ) ; θ p ) , c ∗ ( x ) ) ] \mathbb{E}_{x \sim \mathcal{D}}\left[l\left(h\left(f\left(x ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right), c^*(x)\right)\right] ExD[l(h(f(x;θf);θp),c(x))]
联邦学习 致力于学习一个全局的模型参数 θ \theta θ,其能在每个客户端上都能达到最小的损失:
min ⁡ θ E T k ∈ T [ L k ( θ ) ] (1) \min _{\boldsymbol{\theta}} \mathbb{E}_{\mathcal{T}_k \in \mathcal{T}}\left[\mathcal{L}_k(\boldsymbol{\theta})\right] \tag1 θminETkT[Lk(θ)](1)
其中 T = { T k } k = 1 K \mathcal{T}=\left\{\mathcal{T}_k\right\}_{k=1}^K T={Tk}k=1K表示所有的客户端的任务。 我们假设所有的任务共享相同的标签规则 c ∗ c^* c以及损失函数 T k = ⟨ D k , c ∗ ⟩ \mathcal{T}_k=\left\langle\mathcal{D}_k, c^*\right\rangle Tk=Dk,c。在实际中, 等式1可以进行这样的优化: min ⁡ θ 1 K ∑ k = 1 K L ^ k ( θ ) \min _\theta \frac{1}{K} \sum_{k=1}^K \hat{\mathcal{L}}_k(\boldsymbol{\theta}) minθK1k=1KL^k(θ), 其中 L ^ k ( θ ) : = 1 ∣ D ^ k ∣ ∑ x i ∈ D ^ k [ l ( h ( f ( x i ; θ f ) ; θ p ) , c ∗ ( x i ) ) ] \hat{\mathcal{L}}_k(\boldsymbol{\theta}):=\frac{1}{\left|\hat{\mathcal{D}}_k\right|} \sum_{x_i \in \hat{\mathcal{D}}_k}\left[l\left(h\left(f\left(x_i ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right), c^*\left(x_i\right)\right)\right] L^k(θ):=D^k1xiD^k[l(h(f(xi;θf);θp),c(xi))]表示在数据集 D ^ k \hat{\mathcal{D}}_k D^k上的经验损失。 这里有个隐含的假设是:对于全局数据 D ^ \hat{\mathcal{D}} D^是分布在所有的客户端上: D ^ = ∪ { D ^ k } k = 1 K \hat{\mathcal{D}}=\cup\left\{\hat{\mathcal{D}}_k\right\}_{k=1}^K D^={D^k}k=1K

知识蒸馏(KD) 也被称为师生范式,其目标是学习轻量级的学生模型,使用从一个或多个强大的老师那里提取的知识。典型的KD利用一个代理数据集“pd”来最小化分别来自教师模型 θ T \theta^T θT和学生模型 θ S \theta^S θS的logits输出之间的差异。一个代表性的选择是使用Kullback-Leibler散度来衡量这种差异:
min ⁡ θ S E x ∼ D ^ P [ D K L [ σ ( g ( f ( x ; θ T f ) ; θ T p ) ∥ σ ( g ( f ( x ; θ S f ) ; θ S p ) ] ] \min _{\boldsymbol{\theta}_S} \mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{P}}}\left[D _ { \mathrm { KL } } \left[\sigma\left(g\left(f\left(x ; \boldsymbol{\theta}_T^f\right) ; \boldsymbol{\theta}_T^p\right) \| \sigma\left(g\left(f\left(x ; \boldsymbol{\theta}_S^f\right) ; \boldsymbol{\theta}_S^p\right)\right]\right]\right.\right. θSminExD^P[DKL[σ(g(f(x;θTf);θTp)σ(g(f(x;θSf);θSp)]]
其中 g ( ⋅ ) g(\cdot) g()表示为模型的逻辑输出, σ ( ⋅ ) \sigma(\cdot) σ()表示为激活函数。
KD的想法已经扩展到FL,以解决用户异质性,通过将每个用户模型 θ k \theta_k θk视为教师,其信息聚合到学生(全局)模型 θ \theta θ,以提高其泛化性能:
min ⁡ θ E x ∼ D ^ P [ D K L [ σ ( 1 K ∑ k = 1 K g ( f ( x ; θ k f ) ; θ k p ) ) ∥ σ ( g ( f ( x ; θ f ) ; θ p ) ] ] \min _{\boldsymbol{\theta}} \mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{P}}}\left[D_{\mathrm{KL}}\left[\sigma\left(\frac{1}{K} \sum_{k=1}^K g\left(f\left(x ; \boldsymbol{\theta}_k^f\right) ; \boldsymbol{\theta}_k^p\right)\right) \| \sigma\left(g\left(f\left(x ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right)\right]\right]\right. θminExD^P[DKL[σ(K1k=1Kg(f(x;θkf);θkp))σ(g(f(x;θf);θp)]]
上述方法的一个主要限制在于它依赖于代理数据集“ D ^ P \hat{D}_P D^P”,需要仔细考虑它的选择,并在蒸馏性能中起着关键作用。接下来,我们将展示如何以无数据的方式使KD对FL可行。

三. FEDGEN: 通过生成学习实现无数据的联邦蒸馏

本方法如图所示:
Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析_第1张图片

3.1 知识提取

我们的核心思想是提取关于数据分布的全局视图的知识,这些知识是传统FL无法观察到的,并将这些知识提取到局部模型中,以指导它们的学习。我们首先考虑学习一个条件分布 Q ∗ : Y → X Q^*: \mathcal{Y} \rightarrow \mathcal{X} Q:YX来描述这种知识,它与真实数据分布一致:
Q ∗ = arg ⁡ max ⁡ Q : Y → X E y ∼ p ( y ) E x ∼ Q ( x ∣ y ) [ log ⁡ p ( y ∣ x ) ] , (2) Q^*=\underset{Q: \mathcal{Y} \rightarrow \mathcal{X}}{\arg \max } \mathbb{E}_{y \sim p(y)} \mathbb{E}_{x \sim Q(x \mid y)}[\log p(y \mid x)], \tag2 Q=Q:YXargmaxEyp(y)ExQ(xy)[logp(yx)],(2)
其中 p ( y ) p(y) p(y)表示标签的先验概率而 p ( y ∣ x ) p(y|x) p(yx)表示为后验概率。为了能优化式2,我们替换了 p ( y ) p(y) p(y)以及 p ( y ∣ x ) p(y|x) p(yx)。首先,我们估计 p ( y ) p(y) p(y)为:
p ^ ( y ) ∝ ∑ k E x ∼ D ^ k [ I ( c ∗ ( x ) = y ) ] , \hat{p}(y) \propto \sum_k \mathbb{E}_{x \sim \hat{\mathcal{D}}_k}\left[\mathrm{I}\left(c^*(x)=y\right)\right], p^(y)kExD^k[I(c(x)=y)],
其中 I ( ⋅ ) \mathrm{I}(\cdot) I()为一个指标函数。在实际中, p ^ ( y ) \hat{p}(y) p^(y)可以使用各个客户端训练标签的数量来进行统计。下一步,我们使用各个客户端的集成知识估计 p ( y ∣ x ) p(y|x) p(yx):
log ⁡ p ^ ( y ∣ x ) ∝ 1 K ∑ k = 1 K log ⁡ p ( y ∣ x ; θ k ) \log \hat{p}(y \mid x) \propto \frac{1}{K} \sum_{k=1}^K \log p\left(y \mid x ; \boldsymbol{\theta}_k\right) logp^(yx)K1k=1Klogp(yx;θk)
有了上面的近似之后,直接在输入空间 X \mathcal{X} X上优化式子(2)依然是不行的,因为当 X \mathcal{X} X为高纬的时候会带来计算过载,还可能泄漏用户数据的信息。一个更好的方式是使用 G ∗ : Y → Z G^*: \mathcal{Y} \rightarrow \mathcal{Z} G:YZ去作用于潜在的特征信息,从而避免相关的隐私暴露:
G ∗ = arg ⁡ max ⁡ G : Y → Z E y ∼ p ^ ( y ) E z ∼ G ( z ∣ y ) [ ∑ k = 1 K log ⁡ p ( y ∣ z ; θ k p ) ] (3) G^*=\underset{G: \mathcal{Y} \rightarrow \mathcal{Z}}{\arg \max } \mathbb{E}_{y \sim \hat{p}(y)} \mathbb{E}_{z \sim G(z \mid y)}\left[\sum_{k=1}^K \log p\left(y \mid z ; \boldsymbol{\theta}_k^p\right)\right] \tag3 G=G:YZargmaxEyp^(y)EzG(zy)[k=1Klogp(yz;θkp)](3)
根据上述推理,我们的目标是通过学习条件生成器 G G G进行知识提取,条件生成器 G G G w w w参数化,以优化以下目标:
min ⁡ w J ( w ) : = E y ∼ p ^ ( y ) E z ∼ G w ( z ∣ y ) [ l ( σ ( 1 K ∑ k = 1 K g ( z ; θ k p ) ) , y ) ] (4) \min _{\boldsymbol{w}} J(\boldsymbol{w}):=\mathbb{E}_{y \sim \hat{p}(y)} \mathbb{E}_{z \sim G_{\boldsymbol{w}}(z \mid y)}\left[l\left(\sigma\left(\frac{1}{K} \sum_{k=1}^K g\left(z ; \boldsymbol{\theta}_k^p\right)\right), y\right)\right] \tag4 wminJ(w):=Eyp^(y)EzGw(zy)[l(σ(K1k=1Kg(z;θkp)),y)](4)
其中 g g g σ \sigma σ表示为逻辑输出以及激活函数。这样,给定一系列样本标签,我们只需要使用用户的预测层的参数。具体来说,为了使样本更加多样化,我们创建了噪音向量: ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)

3.2 知识提取

在特征提取之后,我们将学习到的生成器 G w G_w Gw广播给本地用户,以便每个用户模型都可以从 G w G_w Gw中采样,获得特征空间上的增强表示 z ∼ G w ( ⋅ ∣ y ) z \sim G_w(\cdot \mid y) zGw(y)。因此,局部模型 θ k \theta_k θk的目标被改变,以使它对扩增样本产生理想预测的概率最大化:
min ⁡ θ k J ( θ k ) : = L ^ k ( θ k ) + E ^ y ∼ p ^ ( y ) , z ∼ G w ( z ∣ y ) [ l ( h ( z ; θ k p ) ; y ) ] , (5) \min _{\boldsymbol{\theta}_k} J\left(\boldsymbol{\theta}_k\right):=\hat{\mathcal{L}}_k\left(\boldsymbol{\theta}_k\right)+\hat{\mathbb{E}}_{y \sim \hat{p}(y), z \sim G_{\boldsymbol{w}}(z \mid y)}\left[l\left(h\left(z ; \boldsymbol{\theta}_k^p\right) ; y\right)\right], \tag5 θkminJ(θk):=L^k(θk)+E^yp^(y),zGw(zy)[l(h(z;θkp);y)],(5)
其中 L ^ k ( θ k ) : = 1 ∣ D ^ k ∣ ∑ x i ∈ D ^ k [ l ( h ( f ( x i ; θ k f ) ; θ k p ) , c ∗ ( x i ) ) ] \hat{\mathcal{L}}_k\left(\boldsymbol{\theta}_k\right):=\frac{1}{\left|\hat{\mathcal{D}}_k\right|} \sum_{x_i \in \hat{\mathcal{D}}_k}\left[l\left(h\left(f\left(x_i ; \boldsymbol{\theta}_k^f\right) ; \boldsymbol{\theta}_k^p\right), c^*\left(x_i\right)\right)\right] L^k(θk):=D^k1xiD^k[l(h(f(xi;θkf);θkp),c(xi))]表示本地数据集上的相关损失。
相关算法如下:
Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析_第2张图片

四. 代码解析

代码链接点这里
我们首先来看一个客户端的训练:

 def train(self, glob_iter, personalized=False, early_stop=100, regularization=True, verbose=False):
     self.clean_up_counts()
     self.model.train()
     self.generative_model.eval()
     TEACHER_LOSS, DIST_LOSS, LATENT_LOSS = 0, 0, 0
     for epoch in range(self.local_epochs):
         self.model.train()
         for i in range(self.K):
             self.optimizer.zero_grad()
             #### sample from real dataset (un-weighted)
             samples =self.get_next_train_batch(count_labels=True)
             X, y = samples['X'], samples['y']
             self.update_label_counts(samples['labels'], samples['counts'])
             model_result=self.model(X, logit=True)
             user_output_logp = model_result['output']
             predictive_loss=self.loss(user_output_logp, y)

             #### sample y and generate z
             if regularization and epoch < early_stop:
                 generative_alpha=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_alpha)
                 generative_beta=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_beta)
                 ### get generator output(latent representation) of the same label
                 gen_output=self.generative_model(y, latent_layer_idx=self.latent_layer_idx)['output']
                 logit_given_gen=self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True)['logit']
                 target_p=F.softmax(logit_given_gen, dim=1).clone().detach()
                 user_latent_loss= generative_beta * self.ensemble_loss(user_output_logp, target_p)

                 sampled_y=np.random.choice(self.available_labels, self.gen_batch_size)
                 sampled_y=torch.tensor(sampled_y)
                 gen_result=self.generative_model(sampled_y, latent_layer_idx=self.latent_layer_idx)
                 gen_output=gen_result['output'] # latent representation when latent = True, x otherwise
                 user_output_logp =self.model(gen_output, start_layer_idx=self.latent_layer_idx)['output']
                 teacher_loss =  generative_alpha * torch.mean(
                     self.generative_model.crossentropy_loss(user_output_logp, sampled_y)
                 )
                 # this is to further balance oversampled down-sampled synthetic data
                 gen_ratio = self.gen_batch_size / self.batch_size
                 loss=predictive_loss + gen_ratio * teacher_loss + user_latent_loss
                 TEACHER_LOSS+=teacher_loss
                 LATENT_LOSS+=user_latent_loss
             else:
                 #### get loss and perform optimization
                 loss=predictive_loss
             loss.backward()
             self.optimizer.step()#self.local_model)
     # local-model <=== self.model
     self.clone_model_paramenter(self.model.parameters(), self.local_model)
     if personalized:
         self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar)
     self.lr_scheduler.step(glob_iter)
     if regularization and verbose:
         TEACHER_LOSS=TEACHER_LOSS.detach().numpy() / (self.local_epochs * self.K)
         LATENT_LOSS=LATENT_LOSS.detach().numpy() / (self.local_epochs * self.K)
         info='\nUser Teacher Loss={:.4f}'.format(TEACHER_LOSS)
         info+=', Latent Loss={:.4f}'.format(LATENT_LOSS)
         print(info)

我这里给大家画了个损失的计算图:
Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析_第3张图片
大家可以对照着去这部分代码去看。其实很简单,一个表示拥护的预测损失,还有对应的是G产生样本再经过模型的预测层 θ p \theta^p θp进行预测的潜在损失。最后就是针对为教师网络的对应损失,这里使用的是不重复标签(也就是一个类别的y只有一个)。
接下来我们看看服务器的更新:

def train_generator(self, batch_size, epoches=1, latent_layer_idx=-1, verbose=False):
    """
    Learn a generator that find a consensus latent representation z, given a label 'y'.
    :param batch_size:
    :param epoches:
    :param latent_layer_idx: if set to -1 (-2), get latent representation of the last (or 2nd to last) layer.
    :param verbose: print loss information.
    :return: Do not return anything.
    """
    #self.generative_regularizer.train()
    self.label_weights, self.qualified_labels = self.get_label_weights()
    TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, STUDENT_LOSS2 = 0, 0, 0, 0

    def update_generator_(n_iters, student_model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS):
        self.generative_model.train()
        student_model.eval()
        for i in range(n_iters):
            self.generative_optimizer.zero_grad()
            y=np.random.choice(self.qualified_labels, batch_size)
            y_input=torch.LongTensor(y)
            ## feed to generator
            gen_result=self.generative_model(y_input, latent_layer_idx=latent_layer_idx, verbose=True)
            # get approximation of Z( latent) if latent set to True, X( raw image) otherwise
            gen_output, eps=gen_result['output'], gen_result['eps']
            ##### get losses ####
            # decoded = self.generative_regularizer(gen_output)
            # regularization_loss = beta * self.generative_model.dist_loss(decoded, eps) # map generated z back to eps
            diversity_loss=self.generative_model.diversity_loss(eps, gen_output)  # encourage different outputs

            ######### get teacher loss ############
            teacher_loss=0
            teacher_logit=0
            for user_idx, user in enumerate(self.selected_users):
                user.model.eval()
                weight=self.label_weights[y][:, user_idx].reshape(-1, 1)
                expand_weight=np.tile(weight, (1, self.unique_labels))
                user_result_given_gen=user.model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
                user_output_logp_=F.log_softmax(user_result_given_gen['logit'], dim=1)
                teacher_loss_=torch.mean( \
                    self.generative_model.crossentropy_loss(user_output_logp_, y_input) * \
                    torch.tensor(weight, dtype=torch.float32))
                teacher_loss+=teacher_loss_
                teacher_logit+=user_result_given_gen['logit'] * torch.tensor(expand_weight, dtype=torch.float32)

            ######### get student loss ############
            student_output=student_model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
            student_loss=F.kl_div(F.log_softmax(student_output['logit'], dim=1), F.softmax(teacher_logit, dim=1))
            if self.ensemble_beta > 0:
                loss=self.ensemble_alpha * teacher_loss - self.ensemble_beta * student_loss + self.ensemble_eta * diversity_loss
            else:
                loss=self.ensemble_alpha * teacher_loss + self.ensemble_eta * diversity_loss
            loss.backward()
            self.generative_optimizer.step()
            TEACHER_LOSS += self.ensemble_alpha * teacher_loss#(torch.mean(TEACHER_LOSS.double())).item()
            STUDENT_LOSS += self.ensemble_beta * student_loss#(torch.mean(student_loss.double())).item()
            DIVERSITY_LOSS += self.ensemble_eta * diversity_loss#(torch.mean(diversity_loss.double())).item()
        return TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS

    for i in range(epoches):
        TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS=update_generator_(
            self.n_teacher_iters, self.model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)

    TEACHER_LOSS = TEACHER_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
    STUDENT_LOSS = STUDENT_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
    DIVERSITY_LOSS = DIVERSITY_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
    info="Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}, ". \
        format(TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
    if verbose:
        print(info)
    self.generative_lr_scheduler.step()

Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析_第4张图片
其中多样性损失没有写进去,大家根据代码找到即可。

你可能感兴趣的:(每日一次AI论文阅读,论文阅读,联邦学习,无数据知识蒸馏)