The previous post explained how load_lora and set_adapters load a LoRA and attach it to specific components and modules of the model. This post looks at how enable_lora is actually implemented inside the model code.
Here is a quick overview of the three-level code structure of Flux (a rough schematic of how these levels nest follows the list):
(1) the whole transformer backbone, i.e. the tranformer_forward function, located at /path/OminiControl/src/flux/transformer.py;
(2) the two kinds of blocks, MM-DiT (19 of them) and Single-DiT (38 of them), implemented by the block_forward and single_block_forward functions, located at /path/OminiControl/src/flux/block.py;
(3) the attention inside every DiT block, implemented by the attn_forward function, located at /path/OminiControl/src/flux/block.py.
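To visualize that nesting, here is a rough, hypothetical Python schematic of the call hierarchy. The function names match the repository, but the signatures and bodies are simplified stand-ins rather than the real code:

def attn_forward(block, hidden_states):
    # Level 3: the attention inside every DiT block (real code: src/flux/block.py)
    return hidden_states

def block_forward(block, hidden_states, encoder_hidden_states):
    # Level 2a: one MM-DiT block (Flux has 19; real code: src/flux/block.py)
    hidden_states = attn_forward(block, hidden_states)
    return hidden_states, encoder_hidden_states

def single_block_forward(block, hidden_states):
    # Level 2b: one Single-DiT block (Flux has 38; real code: src/flux/block.py)
    return attn_forward(block, hidden_states)

def tranformer_forward(mm_blocks, single_blocks, hidden_states, encoder_hidden_states):
    # Level 1: the transformer backbone (real code: src/flux/transformer.py)
    for block in mm_blocks:        # 19 MM-DiT blocks
        hidden_states, encoder_hidden_states = block_forward(
            block, hidden_states, encoder_hidden_states
        )
    for block in single_blocks:    # 38 Single-DiT blocks
        hidden_states = single_block_forward(block, hidden_states)
    return hidden_states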
The figure below uses SD3 as its example, so it only shows the MM-DiT block, but Flux follows the same kind of three-level structure.
Stepping into tranformer_forward, we find the following code. The intent is to apply LoRA to the x_embedder call on condition_latents while not applying LoRA to the x_embedder call on hidden_states.
with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
    hidden_states = self.x_embedder(hidden_states)
condition_latents = self.x_embedder(condition_latents) if use_condition else None
Given the enable_lora class implementation, we can analyze this snippet more precisely. How the enable_lora class works:

from typing import Any, List, Optional, Type

from peft.tuners.tuners_utils import BaseTunerLayer


class enable_lora:
    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        if activated:
            return
        # Keep only the modules that actually carry a LoRA adapter (peft tuner layers).
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # Remember the current scaling factor of every active adapter.
        self.scales = [
            {
                active_adapter: lora_module.scaling[active_adapter]
                for active_adapter in lora_module.active_adapters
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        if self.activated:
            return
        # Zero the LoRA scaling so the adapters contribute nothing.
        for lora_module in self.lora_modules:
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self.activated:
            return
        # Restore the scaling factors saved in __init__.
        for i, lora_module in enumerate(self.lora_modules):
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            for active_adapter in lora_module.active_adapters:
                lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
The core behavior of this context manager:

At initialization (__init__):
- if activated is True, return immediately and do nothing;
- if activated is False, record the current scaling factors of all LoRA modules.

On entering the context (__enter__):
- if activated is True, do nothing;
- if activated is False, set the scaling factor of every LoRA module to 0, i.e. disable LoRA.

On exiting the context (__exit__):
- if activated is True, do nothing;
- if activated is False, restore the original scaling factors of all LoRA modules.
Back to the original code:

with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
    hidden_states = self.x_embedder(hidden_states)
condition_latents = self.x_embedder(condition_latents) if use_condition else None
Now we can pin down what happens. If model_config is an empty dict {}:
- model_config.get("latent_lora", False) returns False;
- enable_lora((self.x_embedder,), False) therefore temporarily disables LoRA on self.x_embedder.

Inside the context, hidden_states = self.x_embedder(hidden_states) runs with LoRA disabled (scaling factor 0). Outside the context, condition_latents = self.x_embedder(condition_latents) runs after the original scaling has been restored.

So, for an empty model_config:
- hidden_states is embedded without LoRA (activated=False, so the context manager zeroes the scaling factor);
- condition_latents is embedded with LoRA (the call sits outside the context, where the original scaling is back in place).

The net effect: LoRA is applied only to condition_latents and is disabled for hidden_states (assuming an adapter was active in the first place). Conversely, passing model_config = {"latent_lora": True} turns enable_lora into a no-op, so LoRA would then apply to both branches.
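The same conclusion can be checked end to end with a toy stand-in for x_embedder. This is a hypothetical sketch, not OminiControl code: the embedder and its inputs are made up, and the LoRA B matrix is randomized only so the adapter visibly changes the output (a freshly initialized B is all zeros). It again assumes the enable_lora class above and peft:

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

# Toy stand-in for x_embedder: a Linear wrapped with a LoRA adapter.
embedder = get_peft_model(
    nn.Sequential(nn.Linear(16, 16)),
    LoraConfig(r=4, lora_alpha=4, target_modules=["0"]),
).base_model.model[0]
nn.init.normal_(embedder.lora_B["default"].weight)  # make the LoRA branch non-zero

hidden_states = torch.randn(1, 16)
model_config = {}  # empty config, as analyzed above

with enable_lora((embedder,), model_config.get("latent_lora", False)):
    out_inside = embedder(hidden_states)   # scaling forced to 0: no LoRA
out_outside = embedder(hidden_states)      # scaling restored: LoRA applied

print(torch.allclose(out_inside, out_outside))  # False: LoRA only acts outside the context
# With model_config = {"latent_lora": True}, enable_lora is a no-op and the two outputs match.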
There are 7 enable_lora call sites in total: 1 in tranformer_forward, 2 in block_forward (the MM-DiT blocks), 2 in single_block_forward (the Single-DiT blocks), and 2 in attn_forward. However, the modules that actually have LoRA weights loaded into them are only self.norm1.linear, self.norm.linear, attn.to_q and attn.to_k; the other wrapped modules are plain layers, which enable_lora silently skips via its isinstance(each, BaseTunerLayer) filter. You can confirm this by inspecting which modules the LoRA is loaded into at each step (see the inspection sketch at the end of this post).
# the 1 call site in tranformer_forward
with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
    hidden_states = self.x_embedder(hidden_states)
condition_latents = self.x_embedder(condition_latents) if use_condition else None
# the 2 call sites in block_forward (MM-DiT); first one:
with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
        hidden_states, emb=temb
    )
if use_cond:
    (
        norm_condition_latents,
        cond_gate_msa,
        cond_shift_mlp,
        cond_scale_mlp,
        cond_gate_mlp,
    ) = self.norm1(condition_latents, emb=cond_temb)
# Feed-forward.
with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
    # 1. hidden_states
    ff_output = self.ff(norm_hidden_states)
    ff_output = gate_mlp.unsqueeze(1) * ff_output
# 2. encoder_hidden_states
context_ff_output = self.ff_context(norm_encoder_hidden_states)
context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
# 3. condition_latents
if use_cond:
    cond_ff_output = self.ff(norm_condition_latents)
    cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
# the 2 call sites in single_block_forward (Single-DiT); first one:
with enable_lora(
    (
        self.norm.linear,
        self.proj_mlp,
    ),
    model_config.get("latent_lora", False),
):
    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
if using_cond:
    residual_cond = condition_latents
    norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
    mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
    gate = gate.unsqueeze(1)
    hidden_states = gate * self.proj_out(hidden_states)
    hidden_states = residual + hidden_states
if using_cond:
    condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
    cond_gate = cond_gate.unsqueeze(1)
    condition_latents = cond_gate * self.proj_out(condition_latents)
    condition_latents = residual_cond + condition_latents
# the 2 call sites in attn_forward; first one:
with enable_lora(
    (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
):
    # `sample` projections.
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)
# and the second one:
with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
if condition_latents is not None:
    condition_latents = attn.to_out[0](condition_latents)
    condition_latents = attn.to_out[1](condition_latents)
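Finally, to verify which of these wrapped modules actually carry LoRA weights, one can walk the transformer and keep only peft tuner layers. This is a hedged sketch: the helper below is mine, and pipe is assumed to be a Flux pipeline that already has the OminiControl LoRA attached via load_lora / set_adapters, so the usage in the comment is illustrative rather than copy-paste ready:

from peft.tuners.tuners_utils import BaseTunerLayer

def list_lora_modules(model):
    # Names of all submodules that are peft tuner layers, i.e. actually have LoRA attached.
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, BaseTunerLayer)
    ]

# Illustrative usage, assuming `pipe` is the LoRA-loaded Flux pipeline:
#   for name in list_lora_modules(pipe.transformer):
#       print(name)
# Per the observation above, only norm1.linear, norm.linear, attn.to_q and attn.to_k
# show up; the other wrapped call sites (x_embedder, ff.net[2], proj_mlp, proj_out,
# to_v, to_out[0]) remain plain modules, so enable_lora simply skips them.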