一句话定义:
条件变分自编码器(CVAE)是一种生成模型,能够根据给定的条件信息(如标签、文本描述)生成符合特定要求的数据(如图像、文本)。
类比理解:
假设你想让画家画一只“戴墨镜的猫”。传统画家(类似普通VAE)自由发挥,而CVAE是“命题画家”——必须按你的要求创作,且能生成多种风格的结果(如卡通猫、写实猫)。
CVAE的目标是学习条件分布 p ( x ∣ y ) p(x|y) p(x∣y)(给定条件 y y y 生成数据 x x x)。通过引入潜在变量 z z z,将问题分解为:
p ( x ∣ y ) = ∫ p ( x ∣ z , y ) p ( z ∣ y ) d z p(x|y) = \int p(x|z, y) p(z|y) dz p(x∣y)=∫p(x∣z,y)p(z∣y)dz
由于直接计算积分困难,CVAE使用变分推断近似求解。
CVAE通过最大化证据下界(ELBO)来训练模型:
log p ( x ∣ y ) ≥ E z ∼ q [ log p ( x ∣ z , y ) ] − D K L ( q ( z ∣ x , y ) ∥ p ( z ∣ y ) ) \log p(x|y) \geq \mathbb{E}_{z \sim q}[\log p(x|z, y)] - D_{KL}(q(z|x, y) \| p(z|y)) logp(x∣y)≥Ez∼q[logp(x∣z,y)]−DKL(q(z∣x,y)∥p(z∣y))
[0,0,0,1,0,0,0,0,0,0]
。假设编码器网络输出(输出4个数):
μ = [ 0.4 , − 0.2 ] , σ = [ 0.1 , 0.3 ] \mu = [0.4, -0.2], \quad \sigma = [0.1, 0.3] μ=[0.4,−0.2],σ=[0.1,0.3]
即潜在变量分布为:
q ( z ∣ x , y ) = N ( [ 0.4 , − 0.2 ] , [ 0.1 , 0.3 ] ) q(z|x, y) = \mathcal{N}([0.4, -0.2], [0.1, 0.3]) q(z∣x,y)=N([0.4,−0.2],[0.1,0.3])
使用重参数化技巧采样:
z = μ + σ ⊙ ϵ , ϵ ∼ N ( 0 , 1 ) z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1) z=μ+σ⊙ϵ,ϵ∼N(0,1)
假设(\epsilon = [1.0, -0.333]):
z 1 = 0.4 + 0.1 × 1.0 = 0.5 z 2 = − 0.2 + 0.3 × ( − 0.333 ) ≈ − 0.3 z = [ 0.5 , − 0.3 ] z_1 = 0.4 + 0.1 \times 1.0 = 0.5 \\ z_2 = -0.2 + 0.3 \times (-0.333) \approx -0.3 \\ z = [0.5, -0.3] z1=0.4+0.1×1.0=0.5z2=−0.2+0.3×(−0.333)≈−0.3z=[0.5,−0.3]
知识点分类:深度学习批处理原理 + 变分自编码器训练细节
在深度学习中,批量训练(Batch Training) 是标准实践:
案例设定:
[0,0,0,1,0,0,0,0,0,0]
)。每个样本的输入为 concat(图像, 标签)
:
批量输入矩阵:
X batch = [ x 1 ( 1 ) x 1 ( 2 ) ⋯ x 1 ( 784 ) y 1 ( 1 ) ⋯ y 1 ( 10 ) x 2 ( 1 ) x 2 ( 2 ) ⋯ x 2 ( 784 ) y 2 ( 1 ) ⋯ y 2 ( 10 ) x 3 ( 1 ) x 3 ( 2 ) ⋯ x 3 ( 784 ) y 3 ( 1 ) ⋯ y 3 ( 10 ) ] ∈ R 3 × 794 X_{\text{batch}} = \begin{bmatrix} x_1^{(1)} & x_1^{(2)} & \cdots & x_1^{(784)} & y_1^{(1)} & \cdots & y_1^{(10)} \\ x_2^{(1)} & x_2^{(2)} & \cdots & x_2^{(784)} & y_2^{(1)} & \cdots & y_2^{(10)} \\ x_3^{(1)} & x_3^{(2)} & \cdots & x_3^{(784)} & y_3^{(1)} & \cdots & y_3^{(10)} \end{bmatrix} \in \mathbb{R}^{3 \times 794} Xbatch= x1(1)x2(1)x3(1)x1(2)x2(2)x3(2)⋯⋯⋯x1(784)x2(784)x3(784)y1(1)y2(1)y3(1)⋯⋯⋯y1(10)y2(10)y3(10) ∈R3×794
示例输出(假设编码器网络计算得到):
样本 | μ1 | μ2 | logσ1² | logσ2² |
---|---|---|---|---|
1 | 0.4 | -0.2 | -4.605 | -2.407 |
2 | 0.5 | -0.1 | -3.912 | -1.897 |
3 | 0.3 | -0.25 | -4.199 | -2.120 |
转换为实际方差:
σ i = e log σ i 2 \sigma_i = \sqrt{e^{\log\sigma_i^2}} σi=elogσi2
样本 | σ1 | σ2 |
---|---|---|
1 | 0.1 | 0.3 |
2 | 0.2 | 0.4 |
3 | 0.15 | 0.35 |
使用重参数化技巧对每个样本独立采样:
z = μ + σ ⊙ ϵ , ϵ ∼ N ( 0 , 1 ) z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0,1) z=μ+σ⊙ϵ,ϵ∼N(0,1)
示例噪声 ϵ \epsilon ϵ(随机生成):
样本 | ε1 | ε2 |
---|---|---|
1 | 1.0 | -0.333 |
2 | -0.5 | 0.8 |
3 | 0.3 | -0.2 |
计算z:
潜在变量矩阵:
Z batch = [ 0.5 − 0.3 0.4 0.22 0.345 − 0.32 ] ∈ R 3 × 2 Z_{\text{batch}} = \begin{bmatrix} 0.5 & -0.3 \\ 0.4 & 0.22 \\ 0.345 & -0.32 \end{bmatrix} \in \mathbb{R}^{3 \times 2} Zbatch= 0.50.40.345−0.30.22−0.32 ∈R3×2
将每个样本的 z z z 与标签拼接后输入解码器:
concat([0.5, -0.3], y_1) → 12维
。concat([0.4, 0.22], y_2) → 12维
。concat([0.345, -0.32], y_3) → 12维
。解码器输出:
每个样本生成784维像素向量,批量输出矩阵:
X ^ batch = [ x ^ 1 ( 1 ) x ^ 1 ( 2 ) ⋯ x ^ 1 ( 784 ) x ^ 2 ( 1 ) x ^ 2 ( 2 ) ⋯ x ^ 2 ( 784 ) x ^ 3 ( 1 ) x ^ 3 ( 2 ) ⋯ x ^ 3 ( 784 ) ] ∈ R 3 × 784 \hat{X}_{\text{batch}} = \begin{bmatrix} \hat{x}_1^{(1)} & \hat{x}_1^{(2)} & \cdots & \hat{x}_1^{(784)} \\ \hat{x}_2^{(1)} & \hat{x}_2^{(2)} & \cdots & \hat{x}_2^{(784)} \\ \hat{x}_3^{(1)} & \hat{x}_3^{(2)} & \cdots & \hat{x}_3^{(784)} \end{bmatrix} \in \mathbb{R}^{3 \times 784} X^batch= x^1(1)x^2(1)x^3(1)x^1(2)x^2(2)x^3(2)⋯⋯⋯x^1(784)x^2(784)x^3(784) ∈R3×784
重构损失(MSE):对每个样本计算像素级误差后取平均。
KL散度:对每个样本独立计算后取平均。
总损失(假设未加权重):
Loss batch = 0.07 + 2.217 = 2.287 \text{Loss}_{\text{batch}} = 0.07 + 2.217 = 2.287 Lossbatch=0.07+2.217=2.287
反向传播:梯度累积与平均
- 梯度来源:每个样本的损失函数对参数的梯度会被独立计算。例如:
- 样本 x 1 x_1 x1 对 W 1 W_1 W1 的梯度为 ∇ W 1 L 1 \nabla_{W_1} \mathcal{L}_1 ∇W1L1。
- 样本 x 2 x_2 x2 对 W 1 W_1 W1 的梯度为 ∇ W 1 L 2 \nabla_{W_1} \mathcal{L}_2 ∇W1L2。
- 样本 x 3 x_3 x3 对 W 1 W_1 W1 的梯度为 ∇ W 1 L 3 \nabla_{W_1} \mathcal{L}_3 ∇W1L3。
- 梯度平均:参数的实际更新梯度是批量内所有样本梯度的平均值: ∇ W 1 batch = 1 3 ( ∇ W 1 L 1 + ∇ W 1 L 2 + ∇ W 1 L 3 ) \nabla_{W_1}^{\text{batch}} = \frac{1}{3} \left( \nabla_{W_1} \mathcal{L}_1 + \nabla_{W_1} \mathcal{L}_2 + \nabla_{W_1} \mathcal{L}_3 \right) ∇W1batch=31(∇W1L1+∇W1L2+∇W1L3)
- 参数更新:优化器(如Adam)根据平均梯度调整参数: W 1 ← W 1 − η ⋅ ∇ W 1 batch W_1 \leftarrow W_1 - \eta \cdot \nabla_{W_1}^{\text{batch}} W1←W1−η⋅∇W1batch 其中 η \eta η 是学习率。
附:批量训练伪代码
# 假设 batch_size=3, image_dim=784, latent_dim=2
def train_batch(x_batch, y_batch):
# 编码器输入拼接
encoder_input = torch.cat([x_batch, y_batch], dim=1) # shape: (3, 794)
# 编码器输出μ和logσ²
mu_logvar = encoder(encoder_input) # shape: (3, 4)
mu, logvar = mu_logvar[:, :2], mu_logvar[:, 2:]
# 重参数化采样z
eps = torch.randn_like(mu)
z = mu + torch.exp(0.5 * logvar) * eps # shape: (3, 2)
# 解码器输入拼接
decoder_input = torch.cat([z, y_batch], dim=1) # shape: (3, 12)
# 生成图像
x_recon = decoder(decoder_input) # shape: (3, 784)
# 计算损失
recon_loss = F.mse_loss(x_recon, x_batch, reduction='mean') # 批量平均
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
total_loss = recon_loss + kl_loss
# 反向传播
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
特性 | VAE | CVAE |
---|---|---|
生成自由度 | 完全自由 | 受条件约束 |
输入依赖 | 仅数据 x x x | 数据 x x x + 条件 y y y |
应用场景 | 无约束生成(如插值) | 条件生成(如文本到图像) |
CVAE通过条件注入和变分推断,在生成模型中实现了可控性与多样性的平衡。其核心价值在于:
相关代码实现:
# 伪代码示例
class CVAE(nn.Module):
def __init__(self):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Linear(794, 512), nn.ReLU(),
nn.Linear(512, 256), nn.ReLU(),
nn.Linear(256, 4) # 输出μ和logσ²
)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(12, 256), nn.ReLU(),
nn.Linear(256, 512), nn.ReLU(),
nn.Linear(512, 784), nn.Sigmoid()
)
def forward(self, x, y):
# 编码器
mu_logvar = self.encoder(torch.cat([x, y], dim=1))
mu, logvar = mu_logvar[:, :2], mu_logvar[:, 2:]
# 重参数化采样
z = mu + torch.exp(0.5*logvar) * torch.randn_like(mu)
# 解码器
x_recon = self.decoder(torch.cat([z, y], dim=1))
return x_recon, mu, logvar