【diffusers 进阶(十二)】Lora 具体是怎么加入模型的(推理代码篇下)OminiControl

  • 【diffusers 极速入门(一)】pipeline 实际调用的是什么? call 方法!
  • 【diffusers 极速入门(二)】如何得到扩散去噪的中间结果?Pipeline callbacks 管道回调函数
  • 【diffusers极速入门(三)】生成的图像尺寸与 UNet 和 VAE 之间的关系
  • 【diffusers极速入门(四)】EMA 操作是什么?
  • 【diffusers极速入门(五)】扩散模型中的 Scheduler(noise_scheduler)的作用是什么?
  • 【diffusers极速入门(六)】缓存梯度和自动放缩学习率以及代码详解
  • 【diffusers极速入门(七)】Classifier-Free Guidance (CFG)直观理解以及对应代码
  • 【diffusers极速入门(八)】GPU 显存节省(减少内存使用)技巧总结
  • 【diffusers极速入门(九)】GPU 显存节省(减少内存使用)代码总结
  • 【diffusers极速入门(十)】Flux-pipe 推理,完美利用任何显存大小,GPU显存节省终极方案(附代码)
  • 【diffusers 进阶(十一)】Lora 具体是怎么加入模型的(推理代码篇上)OminiControl

上一篇解释了 load_lora 和 set_adapters 如何加载 lora 并将其设置到模型中的具体 component 和 module 中,本文则具体看具体模型中的 enable_lora 是如何实现的。


文章目录

  • 1. Flux 的三层代码结构
  • 2. enable_lora 的上下文管理
    • `enable_lora` 类的工作原理
    • 原代码分析
    • 总结
  • 3.具体的 enable_lora 实施


1. Flux 的三层代码结构

这里简单说一下 flux 的三层代码结构:
(1)整个 transformer backbone,即 tranformer_forward 函数,代码位置 /path/OminiControl/src/flux/transformer.py;
(2)分为两种 MM-DiT(19个) 和 Single-DiT(38个),分别是 block_forwardsingle_block_forward 函数,代码位置 /path/OminiControl/src/flux/block.py;
(3)每种 DiT block 里都有 attention,是 attn_forward 函数,代码位置 /path/OminiControl/src/flux/block.py;

下面是以 SD3 的图为例子,所以只有 MM-DiT block,但 Flux 也是类似的这样的 3 层结构。

2. enable_lora 的上下文管理

我们进入 tranformer_forward 中,会看到以下代码,意思为针对 condition_latents 的 x_embedder 施加 lora,而对 hidden_states 的 x_embedder 并不施加 lora。

    with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
        hidden_states = self.x_embedder(hidden_states)
    condition_latents = self.x_embedder(condition_latents) if use_condition else None

根据提供的 enable_lora 类实现,我可以更准确地分析上述的代码片段。

enable_lora 类的工作原理

class enable_lora:
    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        if activated:
            return
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        self.scales = [
            {
                active_adapter: lora_module.scaling[active_adapter]
                for active_adapter in lora_module.active_adapters
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        if self.activated:
            return

        for lora_module in self.lora_modules:
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self.activated:
            return
        for i, lora_module in enumerate(self.lora_modules):
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            for active_adapter in lora_module.active_adapters:
                lora_module.scaling[active_adapter] = self.scales[i][active_adapter]

这个上下文管理器的核心功能是:

  1. 初始化时

    • 记录 activated 状态
    • 如果 activatedTrue,直接返回(不做任何操作)
    • 如果 activatedFalse,保存所有 LoRA 模块的当前缩放因子
  2. 进入上下文时__enter__):

    • 如果 activatedTrue,不做任何操作
    • 如果 activatedFalse,将所有 LoRA 模块的缩放因子设置为 0(即禁用 LoRA)
  3. 退出上下文时__exit__):

    • 如果 activatedTrue,不做任何操作
    • 如果 activatedFalse,恢复所有 LoRA 模块的原始缩放因子

原代码分析

回到原始代码:

with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
    hidden_states = self.x_embedder(hidden_states)
condition_latents = self.x_embedder(condition_latents) if use_condition else None

现在我们可以确定:

  1. 如果 model_config 是空字典 {}

    • model_config.get("latent_lora", False) 返回 False
    • enable_lora((self.x_embedder,), False)临时禁用 self.x_embedder 中的 LoRA
  2. 在上下文内部:

    • hidden_states = self.x_embedder(hidden_states) 执行时,LoRA 被禁用(缩放因子为0)
  3. 在上下文外部:

    • condition_latents = self.x_embedder(condition_latents) 执行时,LoRA 已恢复到原始状态

总结

对于空字典 model_config 的情况:

  • hidden_states 的处理:禁用 LoRA(因为 activated=False,上下文管理器将缩放因子设为0)
  • condition_latents 的处理:启用 LoRA(因为在上下文外,LoRA 已恢复原始状态)

所以,这段代码的效果是:

  • 处理 hidden_states不使用 LoRA
  • 处理 condition_latents使用 LoRA(如果原本有启用的话)

实际上是只对 condition_latents 应用 LoRA,而对 hidden_states 禁用了 LoRA。

3.具体的 enable_lora 实施

一共 7 个,tranformer_forward 中有 1 个,block_forward 中(MM-DiT中)有 2 个,single_block_forward (Single-DiT中)有2个,attn_forward 中有 2 个。

但具体加入了 lora 的模块只有 self.norm1.linear,self.norm.linear 和 attn.to_q,attn.to_k,具体可以观察每步中该层 load 的模块。

# tranformer_forward 中有 1 个
with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
    hidden_states = self.x_embedder(hidden_states)
condition_latents = self.x_embedder(condition_latents) if use_condition else None

# block_forward 中(MM-DiT中)有以下 2 个
with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
        hidden_states, emb=temb
    )
if use_cond:
    (
        norm_condition_latents,
        cond_gate_msa,
        cond_shift_mlp,
        cond_scale_mlp,
        cond_gate_mlp,
    ) = self.norm1(condition_latents, emb=cond_temb)

# Feed-forward.
with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
    # 1. hidden_states
    ff_output = self.ff(norm_hidden_states)
    ff_output = gate_mlp.unsqueeze(1) * ff_output
# 2. encoder_hidden_states
context_ff_output = self.ff_context(norm_encoder_hidden_states)
context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
# 3. condition_latents
if use_cond:
    cond_ff_output = self.ff(norm_condition_latents)
    cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output

# single_block_forward(Single-DiT中)有2个
    with enable_lora(
        (
            self.norm.linear,
            self.proj_mlp,
        ),
        model_config.get("latent_lora", False),
    ):
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
    if using_cond:
        residual_cond = condition_latents
        norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
        mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))

    with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
    if using_cond:
        condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
        cond_gate = cond_gate.unsqueeze(1)
        condition_latents = cond_gate * self.proj_out(condition_latents)
        condition_latents = residual_cond + condition_latents

# attn_forward 中的 2 个
with enable_lora(
    (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
):
    # `sample` projections.
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)
# 
 with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
     # linear proj
     hidden_states = attn.to_out[0](hidden_states)
     # dropout
     hidden_states = attn.to_out[1](hidden_states)
 if condition_latents is not None:
      condition_latents = attn.to_out[0](condition_latents)
      condition_latents = attn.to_out[1](condition_latents)

你可能感兴趣的:(编程学习,AIGC,Diffusion,python,AIGC,人工智能,stable,diffusion)