论文地址点这里
联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的。在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的数据。
在本文中,我们提出了一个基于无数据的知识蒸馏算法——FEDGEN。具体来说,FEDGEN学习一个仅从用户模型的预测规则派生的生成模型,给定一个目标标签,可以生成与用户预测集合一致的特征表示。该生成器随后广播给用户,在潜在空间上护送他们的模型训练与增强样本,这体现了来自其他同行用户的蒸馏知识。给定一个维数远远小于输入空间的潜在空间,FEDGEN学习到的生成器可以是轻量级的,为当前FL框架引入最小的开销。
我们使用 X ⊂ R p \mathcal{X} \subset \mathbb{R}^p X⊂Rp 表示为输入的实例空间 Z ⊂ R d \mathcal{Z} \subset \mathbb{R}^d Z⊂Rd 表示为潜在的特征空间,其中 d < p d d<p
联邦学习 致力于学习一个全局的模型参数 θ \theta θ,其能在每个客户端上都能达到最小的损失:
min θ E T k ∈ T [ L k ( θ ) ] (1) \min _{\boldsymbol{\theta}} \mathbb{E}_{\mathcal{T}_k \in \mathcal{T}}\left[\mathcal{L}_k(\boldsymbol{\theta})\right] \tag1 θminETk∈T[Lk(θ)](1)
其中 T = { T k } k = 1 K \mathcal{T}=\left\{\mathcal{T}_k\right\}_{k=1}^K T={Tk}k=1K表示所有的客户端的任务。 我们假设所有的任务共享相同的标签规则 c ∗ c^* c∗以及损失函数 T k = ⟨ D k , c ∗ ⟩ \mathcal{T}_k=\left\langle\mathcal{D}_k, c^*\right\rangle Tk=⟨Dk,c∗⟩。在实际中, 等式1可以进行这样的优化: min θ 1 K ∑ k = 1 K L ^ k ( θ ) \min _\theta \frac{1}{K} \sum_{k=1}^K \hat{\mathcal{L}}_k(\boldsymbol{\theta}) minθK1∑k=1KL^k(θ), 其中 L ^ k ( θ ) : = 1 ∣ D ^ k ∣ ∑ x i ∈ D ^ k [ l ( h ( f ( x i ; θ f ) ; θ p ) , c ∗ ( x i ) ) ] \hat{\mathcal{L}}_k(\boldsymbol{\theta}):=\frac{1}{\left|\hat{\mathcal{D}}_k\right|} \sum_{x_i \in \hat{\mathcal{D}}_k}\left[l\left(h\left(f\left(x_i ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right), c^*\left(x_i\right)\right)\right] L^k(θ):=∣D^k∣1∑xi∈D^k[l(h(f(xi;θf);θp),c∗(xi))]表示在数据集 D ^ k \hat{\mathcal{D}}_k D^k上的经验损失。 这里有个隐含的假设是:对于全局数据 D ^ \hat{\mathcal{D}} D^是分布在所有的客户端上: D ^ = ∪ { D ^ k } k = 1 K \hat{\mathcal{D}}=\cup\left\{\hat{\mathcal{D}}_k\right\}_{k=1}^K D^=∪{D^k}k=1K。
知识蒸馏(KD) 也被称为师生范式,其目标是学习轻量级的学生模型,使用从一个或多个强大的老师那里提取的知识。典型的KD利用一个代理数据集“pd”来最小化分别来自教师模型 θ T \theta^T θT和学生模型 θ S \theta^S θS的logits输出之间的差异。一个代表性的选择是使用Kullback-Leibler散度来衡量这种差异:
min θ S E x ∼ D ^ P [ D K L [ σ ( g ( f ( x ; θ T f ) ; θ T p ) ∥ σ ( g ( f ( x ; θ S f ) ; θ S p ) ] ] \min _{\boldsymbol{\theta}_S} \mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{P}}}\left[D _ { \mathrm { KL } } \left[\sigma\left(g\left(f\left(x ; \boldsymbol{\theta}_T^f\right) ; \boldsymbol{\theta}_T^p\right) \| \sigma\left(g\left(f\left(x ; \boldsymbol{\theta}_S^f\right) ; \boldsymbol{\theta}_S^p\right)\right]\right]\right.\right. θSminEx∼D^P[DKL[σ(g(f(x;θTf);θTp)∥σ(g(f(x;θSf);θSp)]]
其中 g ( ⋅ ) g(\cdot) g(⋅)表示为模型的逻辑输出, σ ( ⋅ ) \sigma(\cdot) σ(⋅)表示为激活函数。
KD的想法已经扩展到FL,以解决用户异质性,通过将每个用户模型 θ k \theta_k θk视为教师,其信息聚合到学生(全局)模型 θ \theta θ,以提高其泛化性能:
min θ E x ∼ D ^ P [ D K L [ σ ( 1 K ∑ k = 1 K g ( f ( x ; θ k f ) ; θ k p ) ) ∥ σ ( g ( f ( x ; θ f ) ; θ p ) ] ] \min _{\boldsymbol{\theta}} \mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{P}}}\left[D_{\mathrm{KL}}\left[\sigma\left(\frac{1}{K} \sum_{k=1}^K g\left(f\left(x ; \boldsymbol{\theta}_k^f\right) ; \boldsymbol{\theta}_k^p\right)\right) \| \sigma\left(g\left(f\left(x ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right)\right]\right]\right. θminEx∼D^P[DKL[σ(K1k=1∑Kg(f(x;θkf);θkp))∥σ(g(f(x;θf);θp)]]
上述方法的一个主要限制在于它依赖于代理数据集“ D ^ P \hat{D}_P D^P”,需要仔细考虑它的选择,并在蒸馏性能中起着关键作用。接下来,我们将展示如何以无数据的方式使KD对FL可行。
我们的核心思想是提取关于数据分布的全局视图的知识,这些知识是传统FL无法观察到的,并将这些知识提取到局部模型中,以指导它们的学习。我们首先考虑学习一个条件分布 Q ∗ : Y → X Q^*: \mathcal{Y} \rightarrow \mathcal{X} Q∗:Y→X来描述这种知识,它与真实数据分布一致:
Q ∗ = arg max Q : Y → X E y ∼ p ( y ) E x ∼ Q ( x ∣ y ) [ log p ( y ∣ x ) ] , (2) Q^*=\underset{Q: \mathcal{Y} \rightarrow \mathcal{X}}{\arg \max } \mathbb{E}_{y \sim p(y)} \mathbb{E}_{x \sim Q(x \mid y)}[\log p(y \mid x)], \tag2 Q∗=Q:Y→XargmaxEy∼p(y)Ex∼Q(x∣y)[logp(y∣x)],(2)
其中 p ( y ) p(y) p(y)表示标签的先验概率而 p ( y ∣ x ) p(y|x) p(y∣x)表示为后验概率。为了能优化式2,我们替换了 p ( y ) p(y) p(y)以及 p ( y ∣ x ) p(y|x) p(y∣x)。首先,我们估计 p ( y ) p(y) p(y)为:
p ^ ( y ) ∝ ∑ k E x ∼ D ^ k [ I ( c ∗ ( x ) = y ) ] , \hat{p}(y) \propto \sum_k \mathbb{E}_{x \sim \hat{\mathcal{D}}_k}\left[\mathrm{I}\left(c^*(x)=y\right)\right], p^(y)∝k∑Ex∼D^k[I(c∗(x)=y)],
其中 I ( ⋅ ) \mathrm{I}(\cdot) I(⋅)为一个指标函数。在实际中, p ^ ( y ) \hat{p}(y) p^(y)可以使用各个客户端训练标签的数量来进行统计。下一步,我们使用各个客户端的集成知识估计 p ( y ∣ x ) p(y|x) p(y∣x):
log p ^ ( y ∣ x ) ∝ 1 K ∑ k = 1 K log p ( y ∣ x ; θ k ) \log \hat{p}(y \mid x) \propto \frac{1}{K} \sum_{k=1}^K \log p\left(y \mid x ; \boldsymbol{\theta}_k\right) logp^(y∣x)∝K1k=1∑Klogp(y∣x;θk)
有了上面的近似之后,直接在输入空间 X \mathcal{X} X上优化式子(2)依然是不行的,因为当 X \mathcal{X} X为高纬的时候会带来计算过载,还可能泄漏用户数据的信息。一个更好的方式是使用 G ∗ : Y → Z G^*: \mathcal{Y} \rightarrow \mathcal{Z} G∗:Y→Z去作用于潜在的特征信息,从而避免相关的隐私暴露:
G ∗ = arg max G : Y → Z E y ∼ p ^ ( y ) E z ∼ G ( z ∣ y ) [ ∑ k = 1 K log p ( y ∣ z ; θ k p ) ] (3) G^*=\underset{G: \mathcal{Y} \rightarrow \mathcal{Z}}{\arg \max } \mathbb{E}_{y \sim \hat{p}(y)} \mathbb{E}_{z \sim G(z \mid y)}\left[\sum_{k=1}^K \log p\left(y \mid z ; \boldsymbol{\theta}_k^p\right)\right] \tag3 G∗=G:Y→ZargmaxEy∼p^(y)Ez∼G(z∣y)[k=1∑Klogp(y∣z;θkp)](3)
根据上述推理,我们的目标是通过学习条件生成器 G G G进行知识提取,条件生成器 G G G由 w w w参数化,以优化以下目标:
min w J ( w ) : = E y ∼ p ^ ( y ) E z ∼ G w ( z ∣ y ) [ l ( σ ( 1 K ∑ k = 1 K g ( z ; θ k p ) ) , y ) ] (4) \min _{\boldsymbol{w}} J(\boldsymbol{w}):=\mathbb{E}_{y \sim \hat{p}(y)} \mathbb{E}_{z \sim G_{\boldsymbol{w}}(z \mid y)}\left[l\left(\sigma\left(\frac{1}{K} \sum_{k=1}^K g\left(z ; \boldsymbol{\theta}_k^p\right)\right), y\right)\right] \tag4 wminJ(w):=Ey∼p^(y)Ez∼Gw(z∣y)[l(σ(K1k=1∑Kg(z;θkp)),y)](4)
其中 g g g和 σ \sigma σ表示为逻辑输出以及激活函数。这样,给定一系列样本标签,我们只需要使用用户的预测层的参数。具体来说,为了使样本更加多样化,我们创建了噪音向量: ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵ∼N(0,I)。
在特征提取之后,我们将学习到的生成器 G w G_w Gw广播给本地用户,以便每个用户模型都可以从 G w G_w Gw中采样,获得特征空间上的增强表示 z ∼ G w ( ⋅ ∣ y ) z \sim G_w(\cdot \mid y) z∼Gw(⋅∣y)。因此,局部模型 θ k \theta_k θk的目标被改变,以使它对扩增样本产生理想预测的概率最大化:
min θ k J ( θ k ) : = L ^ k ( θ k ) + E ^ y ∼ p ^ ( y ) , z ∼ G w ( z ∣ y ) [ l ( h ( z ; θ k p ) ; y ) ] , (5) \min _{\boldsymbol{\theta}_k} J\left(\boldsymbol{\theta}_k\right):=\hat{\mathcal{L}}_k\left(\boldsymbol{\theta}_k\right)+\hat{\mathbb{E}}_{y \sim \hat{p}(y), z \sim G_{\boldsymbol{w}}(z \mid y)}\left[l\left(h\left(z ; \boldsymbol{\theta}_k^p\right) ; y\right)\right], \tag5 θkminJ(θk):=L^k(θk)+E^y∼p^(y),z∼Gw(z∣y)[l(h(z;θkp);y)],(5)
其中 L ^ k ( θ k ) : = 1 ∣ D ^ k ∣ ∑ x i ∈ D ^ k [ l ( h ( f ( x i ; θ k f ) ; θ k p ) , c ∗ ( x i ) ) ] \hat{\mathcal{L}}_k\left(\boldsymbol{\theta}_k\right):=\frac{1}{\left|\hat{\mathcal{D}}_k\right|} \sum_{x_i \in \hat{\mathcal{D}}_k}\left[l\left(h\left(f\left(x_i ; \boldsymbol{\theta}_k^f\right) ; \boldsymbol{\theta}_k^p\right), c^*\left(x_i\right)\right)\right] L^k(θk):=∣D^k∣1∑xi∈D^k[l(h(f(xi;θkf);θkp),c∗(xi))]表示本地数据集上的相关损失。
相关算法如下:
代码链接点这里
我们首先来看一个客户端的训练:
def train(self, glob_iter, personalized=False, early_stop=100, regularization=True, verbose=False):
self.clean_up_counts()
self.model.train()
self.generative_model.eval()
TEACHER_LOSS, DIST_LOSS, LATENT_LOSS = 0, 0, 0
for epoch in range(self.local_epochs):
self.model.train()
for i in range(self.K):
self.optimizer.zero_grad()
#### sample from real dataset (un-weighted)
samples =self.get_next_train_batch(count_labels=True)
X, y = samples['X'], samples['y']
self.update_label_counts(samples['labels'], samples['counts'])
model_result=self.model(X, logit=True)
user_output_logp = model_result['output']
predictive_loss=self.loss(user_output_logp, y)
#### sample y and generate z
if regularization and epoch < early_stop:
generative_alpha=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_alpha)
generative_beta=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_beta)
### get generator output(latent representation) of the same label
gen_output=self.generative_model(y, latent_layer_idx=self.latent_layer_idx)['output']
logit_given_gen=self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True)['logit']
target_p=F.softmax(logit_given_gen, dim=1).clone().detach()
user_latent_loss= generative_beta * self.ensemble_loss(user_output_logp, target_p)
sampled_y=np.random.choice(self.available_labels, self.gen_batch_size)
sampled_y=torch.tensor(sampled_y)
gen_result=self.generative_model(sampled_y, latent_layer_idx=self.latent_layer_idx)
gen_output=gen_result['output'] # latent representation when latent = True, x otherwise
user_output_logp =self.model(gen_output, start_layer_idx=self.latent_layer_idx)['output']
teacher_loss = generative_alpha * torch.mean(
self.generative_model.crossentropy_loss(user_output_logp, sampled_y)
)
# this is to further balance oversampled down-sampled synthetic data
gen_ratio = self.gen_batch_size / self.batch_size
loss=predictive_loss + gen_ratio * teacher_loss + user_latent_loss
TEACHER_LOSS+=teacher_loss
LATENT_LOSS+=user_latent_loss
else:
#### get loss and perform optimization
loss=predictive_loss
loss.backward()
self.optimizer.step()#self.local_model)
# local-model <=== self.model
self.clone_model_paramenter(self.model.parameters(), self.local_model)
if personalized:
self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar)
self.lr_scheduler.step(glob_iter)
if regularization and verbose:
TEACHER_LOSS=TEACHER_LOSS.detach().numpy() / (self.local_epochs * self.K)
LATENT_LOSS=LATENT_LOSS.detach().numpy() / (self.local_epochs * self.K)
info='\nUser Teacher Loss={:.4f}'.format(TEACHER_LOSS)
info+=', Latent Loss={:.4f}'.format(LATENT_LOSS)
print(info)
我这里给大家画了个损失的计算图:
大家可以对照着去这部分代码去看。其实很简单,一个表示拥护的预测损失,还有对应的是G产生样本再经过模型的预测层 θ p \theta^p θp进行预测的潜在损失。最后就是针对为教师网络的对应损失,这里使用的是不重复标签(也就是一个类别的y只有一个)。
接下来我们看看服务器的更新:
def train_generator(self, batch_size, epoches=1, latent_layer_idx=-1, verbose=False):
"""
Learn a generator that find a consensus latent representation z, given a label 'y'.
:param batch_size:
:param epoches:
:param latent_layer_idx: if set to -1 (-2), get latent representation of the last (or 2nd to last) layer.
:param verbose: print loss information.
:return: Do not return anything.
"""
#self.generative_regularizer.train()
self.label_weights, self.qualified_labels = self.get_label_weights()
TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, STUDENT_LOSS2 = 0, 0, 0, 0
def update_generator_(n_iters, student_model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS):
self.generative_model.train()
student_model.eval()
for i in range(n_iters):
self.generative_optimizer.zero_grad()
y=np.random.choice(self.qualified_labels, batch_size)
y_input=torch.LongTensor(y)
## feed to generator
gen_result=self.generative_model(y_input, latent_layer_idx=latent_layer_idx, verbose=True)
# get approximation of Z( latent) if latent set to True, X( raw image) otherwise
gen_output, eps=gen_result['output'], gen_result['eps']
##### get losses ####
# decoded = self.generative_regularizer(gen_output)
# regularization_loss = beta * self.generative_model.dist_loss(decoded, eps) # map generated z back to eps
diversity_loss=self.generative_model.diversity_loss(eps, gen_output) # encourage different outputs
######### get teacher loss ############
teacher_loss=0
teacher_logit=0
for user_idx, user in enumerate(self.selected_users):
user.model.eval()
weight=self.label_weights[y][:, user_idx].reshape(-1, 1)
expand_weight=np.tile(weight, (1, self.unique_labels))
user_result_given_gen=user.model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
user_output_logp_=F.log_softmax(user_result_given_gen['logit'], dim=1)
teacher_loss_=torch.mean( \
self.generative_model.crossentropy_loss(user_output_logp_, y_input) * \
torch.tensor(weight, dtype=torch.float32))
teacher_loss+=teacher_loss_
teacher_logit+=user_result_given_gen['logit'] * torch.tensor(expand_weight, dtype=torch.float32)
######### get student loss ############
student_output=student_model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
student_loss=F.kl_div(F.log_softmax(student_output['logit'], dim=1), F.softmax(teacher_logit, dim=1))
if self.ensemble_beta > 0:
loss=self.ensemble_alpha * teacher_loss - self.ensemble_beta * student_loss + self.ensemble_eta * diversity_loss
else:
loss=self.ensemble_alpha * teacher_loss + self.ensemble_eta * diversity_loss
loss.backward()
self.generative_optimizer.step()
TEACHER_LOSS += self.ensemble_alpha * teacher_loss#(torch.mean(TEACHER_LOSS.double())).item()
STUDENT_LOSS += self.ensemble_beta * student_loss#(torch.mean(student_loss.double())).item()
DIVERSITY_LOSS += self.ensemble_eta * diversity_loss#(torch.mean(diversity_loss.double())).item()
return TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS
for i in range(epoches):
TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS=update_generator_(
self.n_teacher_iters, self.model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
TEACHER_LOSS = TEACHER_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
STUDENT_LOSS = STUDENT_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
DIVERSITY_LOSS = DIVERSITY_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
info="Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}, ". \
format(TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
if verbose:
print(info)
self.generative_lr_scheduler.step()