掩码图像建模 (MIM) 中的对数似然与交叉熵

掩码图像建模 (MIM) 中的对数似然与交叉熵

1. 问题背景

在掩码图像建模(MIM)任务中,模型需要预测被遮蔽的图像块对应的视觉词元(可以理解为图像块的离散类别标签)。

具体来说:

  • 每个被遮蔽的图像块 i ∈ M i \in M iM 的真实标签是 z i z_i zi(即它原本的视觉词元类别)。
  • 模型通过 Transformer 编码器生成隐藏向量 h L i h_L^i hLi,然后通过一个分类器(参数为 W c , b c W_c, b_c Wc,bc)预测该位置的概率分布 p MIM ( z ′ ∣ x M ) p_{\text{MIM}}(z' | x^M) pMIM(zxM)

2. Softmax 分类器的作用

分类器的公式是:
p MIM ( z ′ ∣ x M ) = softmax z ( W c h L i + b c ) p_{\text{MIM}}(z' | x^M) = \text{softmax}_z(W_c h_L^i + b_c) pMIM(zxM)=softmaxz(WchLi+bc)

  • 输入:隐藏向量 h L i ∈ R D h_L^i \in \mathbb{R}^D hLiRD(来自 Transformer 的输出)。
  • 参数:权重矩阵 W c ∈ R ∣ V ∣ × D W_c \in \mathbb{R}^{|\mathcal{V}| \times D} WcRV×D 和偏置 b c ∈ R ∣ V ∣ b_c \in \mathbb{R}^{|\mathcal{V}|} bcRV,其中 ∣ V ∣ |\mathcal{V}| V 是视觉词元的总类别数。
  • 输出:一个概率分布,表示模型认为被遮蔽块 i i i 属于每个视觉词元类别的概率。

具体计算步骤

  1. 对每个被遮蔽位置 i i i,计算线性变换: W c h L i + b c W_c h_L^i + b_c WchLi+bc,得到一个长度为 ∣ V ∣ |\mathcal{V}| V 的向量(称为logits)。
  2. 对 logits 应用 softmax 函数,将其转换为概率分布:
    p ( z ′ ) = exp ⁡ ( logits [ z ′ ] ) ∑ k = 1 ∣ V ∣ exp ⁡ ( logits [ k ] ) p(z') = \frac{\exp(\text{logits}[z'])}{\sum_{k=1}^{|\mathcal{V}|} \exp(\text{logits}[k])} p(z)=k=1Vexp(logits[k])exp(logits[z])
    其中 z ′ z' z 是某个可能的视觉词元类别。

3. 最大化对数似然(Maximize Log-Likelihood)

目标:让模型对真实标签 z i z_i zi 的预测概率尽可能高。

数学表达:
max ⁡ θ E x ∼ D [ ∑ i ∈ M log ⁡ p MIM ( z i ∣ x M ) ] \max_{\theta} \mathbb{E}_{x \sim \mathcal{D}} \left[ \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right] θmaxExD[iMlogpMIM(zixM)]

  • 解释
    • 对每个被遮蔽位置 i i i,计算真实标签 z i z_i zi 的对数概率 log ⁡ p MIM ( z i ∣ x M ) \log p_{\text{MIM}}(z_i | x^M) logpMIM(zixM)
    • 对所有被遮蔽位置求和,再对所有训练样本 x x x 求期望。
    • 目标是最大化这个总和,即让模型对真实标签的预测概率尽可能大。

4. 交叉熵损失(Cross-Entropy Loss)

交叉熵损失是分类任务中常用的损失函数,定义为:
L CE = − ∑ i ∈ M log ⁡ p MIM ( z i ∣ x M ) \mathcal{L}_{\text{CE}} = - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) LCE=iMlogpMIM(zixM)

  • 解释
    • 对每个被遮蔽位置 i i i,计算真实标签 z i z_i zi 的负对数概率。
    • 对所有被遮蔽位置求和,得到总损失。
    • 目标是最小化这个损失,即让真实标签的预测概率尽可能高。

5. 最大化对数似然 vs. 最小化交叉熵

关键结论
最大化对数似然最小化交叉熵损失完全等价的!

具体来说:
max ⁡ θ ∑ i ∈ M log ⁡ p MIM ( z i ∣ x M )    ⟺    min ⁡ θ ( − ∑ i ∈ M log ⁡ p MIM ( z i ∣ x M ) ) \max_{\theta} \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \quad \iff \quad \min_{\theta} \left( - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right) θmaxiMlogpMIM(zixM)θmin(iMlogpMIM(zixM))

  • 左边是最大化对数似然(使正确标签的概率最大化)。
  • 右边是最小化交叉熵损失(使正确标签的负对数概率最小化)。

6. 为什么等价?

  • 数学本质:交叉熵损失是负的对数似然。
    • 对数似然是 ∑ log ⁡ p \sum \log p logp,交叉熵是 − ∑ log ⁡ p -\sum \log p logp
    • 最大化 A A A 等价于最小化 − A -A A
  • 直观理解
    • 如果模型对真实标签的预测概率 p ( z i ) p(z_i) p(zi) 越大,对数似然 log ⁡ p ( z i ) \log p(z_i) logp(zi) 越大,交叉熵损失 − log ⁡ p ( z i ) -\log p(z_i) logp(zi) 越小。
    • 例如,若真实标签的概率 p ( z i ) = 0.9 p(z_i) = 0.9 p(zi)=0.9,则交叉熵损失为 − log ⁡ ( 0.9 ) ≈ 0.11 -\log(0.9) \approx 0.11 log(0.9)0.11
      若概率 p ( z i ) = 0.1 p(z_i) = 0.1 p(zi)=0.1,则损失为 − log ⁡ ( 0.1 ) ≈ 2.30 -\log(0.1) \approx 2.30 log(0.1)2.30
      显然,概率越大,损失越小。

7. 实际训练中的计算

在代码中,通常直接使用交叉熵损失函数(如 PyTorch 的 CrossEntropyLoss):

# 假设 logits 是模型的输出(未经过 softmax)
# targets 是被遮蔽位置的真实视觉词元标签
loss = F.cross_entropy(logits, targets)
  • 内部过程
    1. 对 logits 应用 softmax,得到概率分布。
    2. 计算真实标签的负对数概率。
    3. 对所有样本和位置求平均,得到最终损失。

总结

  • 目标:让模型对真实标签的预测概率尽可能高。
  • 数学实现:通过最大化对数似然(等价于最小化交叉熵损失)。
  • 代码实现:直接使用交叉熵损失函数,无需手动计算对数似然。

你可能感兴趣的:(深度学习,LLM,人工智能,深度学习,计算机视觉)