如何从一个任意的均值 μ 方差 σ^2 的高斯分布中采样得到噪声xt

1)可以首先从一个标准高斯分布(均值0,方差1)中进行采样得到噪声 ε

noise = torch.randn_like(x_0)

2)然后利用 μ + σ·ε 就等价于从任意均值 μ 方差 σ^2 的高斯分布中采样(首先从标准高斯分布中采样得到噪声 ε,接着乘以标准差再加上均值)。公式表示如下:

xt = sqrt(1-betas[t]) * xt-1 + sqrt(betas[t]) * noise

完整代码:

# https://pytorch.org/docs/stable/generated/torch.randn_like.html
betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
noise = torch.randn_like(x_0)
xt = sqrt(1-betas[t]) * xt-1 + sqrt(betas[t]) * noise

例子:

该函数torch.randn生成一个张量,其元素取自零均值和单位方差的高斯分布。乘以得到sqrt(0.1)所需的方差。

x = torch.zeros(5, 10, 20, dtype=torch.float64)
x = x + (0.1**0.5)*torch.randn(5, 10, 20)

你可能感兴趣的:(均值算法,深度学习,pytorch)