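A Gaussian mixture model (GMM) is a weighted superposition of several Gaussian distributions. Mathematically, the probability density functions of $K$ different Gaussians are combined in a weighted sum to form a new probability density function that describes the distribution of the samples: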
$$
p(x)=\sum_{k=1}^K\alpha_k N(x\mid\mu_k,\Sigma_k),\quad \text{where } \sum_{k=1}^K\alpha_k=1
$$
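To make the weighted sum concrete, here is a minimal NumPy sketch that evaluates the mixture density at a single point. The helper names `gaussian_pdf` and `mixture_pdf` and the two-component parameters are made up for illustration; they are not from any library.

```python
import numpy as np

def gaussian_pdf(x, mean, cov):
    """Density of a multivariate normal N(mean, cov) evaluated at one point x."""
    x = np.asarray(x, dtype=float)
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    d = mean.shape[0]
    diff = x - mean
    # Solve cov @ v = diff rather than forming the inverse explicitly.
    maha = diff @ np.linalg.solve(cov, diff)
    norm = np.sqrt((2.0 * np.pi) ** d * np.linalg.det(cov))
    return float(np.exp(-0.5 * maha) / norm)

def mixture_pdf(x, weights, means, covs):
    """p(x) = sum_k alpha_k * N(x | mu_k, Sigma_k)."""
    return sum(w * gaussian_pdf(x, m, c) for w, m, c in zip(weights, means, covs))

# Hypothetical two-component mixture in 2D (values chosen for illustration only).
weights = [0.3, 0.7]
means = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]
covs = [np.eye(2), 2.0 * np.eye(2)]

p_at_origin = mixture_pdf([0.0, 0.0], weights, means, covs)
print(p_at_origin)
```

At the origin almost all of the density comes from the first component, since the second one is centered far away.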
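A Gaussian mixture model contains a latent variable: for each sample, we do not know which Gaussian component generated it. Fitting the model therefore requires the EM algorithm. The parameters that EM has to estimate are the mixing weights (written $p_k$ here, the same quantities as the $\alpha_k$ in the density), the means, and the covariances: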
$$
\theta=(p_1,p_2,\cdots,p_K,\mu_1,\mu_2,\cdots,\mu_K,\Sigma_1,\Sigma_2,\cdots,\Sigma_K)
$$
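The role of the latent variable is easiest to see from the generative side. The sketch below, which assumes a hypothetical two-component parameter set, first draws the hidden component index $z_i$ and then draws $x_i$ from the selected Gaussian. When fitting, only `x` is available, which is why the assignment has to be inferred.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameters theta = (p, mu, Sigma), for illustration only.
weights = np.array([0.3, 0.7])
means = np.array([[0.0, 0.0], [3.0, 3.0]])
covs = np.array([np.eye(2), 2.0 * np.eye(2)])

n_samples = 1000
# Step 1: draw the latent component index z_i for every sample.
z = rng.choice(len(weights), size=n_samples, p=weights)
# Step 2: draw x_i from the Gaussian selected by z_i.
x = np.stack([rng.multivariate_normal(means[k], covs[k]) for k in z])

# Only x is observed in practice; z is shown here just to check the sampling.
component_fraction = np.bincount(z, minlength=len(weights)) / n_samples
print(component_fraction)
```

The empirical fractions should land close to the mixing weights, although the exact values depend on the random draw.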
Running the EM algorithm gives the following iterative updates for $\theta$:
$$
\begin{aligned}
&\mu_k^{(t+1)}=\frac{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})\,x_i}{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})} \\
&\Sigma_k^{(t+1)}=\frac{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})\,(x_i-\mu_k^{(t+1)})(x_i-\mu_k^{(t+1)})^T}{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})} \\
&P(z_i\mid x_i,\theta^{(t)})=\frac{p_{z_i}^{(t)}N(x_i\mid\mu_{z_i}^{(t)},\Sigma_{z_i}^{(t)})}{\sum_{k=1}^{K}p_k^{(t)}N(x_i\mid\mu_k^{(t)},\Sigma_k^{(t)})}
\end{aligned}
$$
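The updates can be sketched in plain NumPy as one E-step followed by one M-step. This is a minimal illustration under stated assumptions, not a production implementation: `gaussian_pdf` and `em_step` are hypothetical helper names, the weight update $p_k = N_k / N$ is the standard M-step result and is not among the formulas listed above, the covariance is centered on the newly updated mean (which is what the exact M-step gives), and there is no safeguard against underflow or singular covariances.

```python
import numpy as np

def gaussian_pdf(X, mean, cov):
    """Row-wise density of N(mean, cov) for data X of shape (N, d)."""
    d = X.shape[1]
    diff = X - mean
    maha = np.einsum("ij,ij->i", diff @ np.linalg.inv(cov), diff)
    norm = np.sqrt((2.0 * np.pi) ** d * np.linalg.det(cov))
    return np.exp(-0.5 * maha) / norm

def em_step(X, weights, means, covs):
    """One EM iteration; returns updated parameters and the current log-likelihood."""
    n, d = X.shape
    K = len(weights)
    # E-step: responsibilities P(z_i = C_k | x_i, theta^(t)).
    dens = np.stack(
        [weights[k] * gaussian_pdf(X, means[k], covs[k]) for k in range(K)], axis=1
    )
    resp = dens / dens.sum(axis=1, keepdims=True)
    # M-step: responsibility-weighted averages.
    Nk = resp.sum(axis=0)
    new_weights = Nk / n
    new_means = (resp.T @ X) / Nk[:, None]
    new_covs = np.empty((K, d, d))
    for k in range(K):
        diff = X - new_means[k]
        new_covs[k] = (resp[:, k, None] * diff).T @ diff / Nk[k]
    log_lik = np.log(dens.sum(axis=1)).sum()  # evaluated at the input parameters
    return new_weights, new_means, new_covs, log_lik

# Synthetic data: two well-separated clusters (assumed for illustration).
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, size=(200, 2)),
               rng.normal(5.0, 1.0, size=(200, 2))])

# Crude initialisation: equal weights, first and last samples as means,
# and the overall data covariance for every component.
weights = np.array([0.5, 0.5])
means = X[[0, -1]].copy()
covs = np.array([np.cov(X.T), np.cov(X.T)])

log_liks = []
for _ in range(100):
    weights, means, covs, ll = em_step(X, weights, means, covs)
    log_liks.append(ll)

print(weights)
print(means)
```

The recorded log-likelihood should not decrease from one iteration to the next, which is a handy sanity check when writing EM by hand.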
# -*- coding: utf-8 -*-
# @Use :
# @Time : 2022/8/27 21:30
# @FileName: MixGaussian.py
# @Software: PyCharm
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs  # public import path; the private _samples_generator module is not a stable API
X, y_true = make_blobs(n_samples=1000, centers=4)
fig, ax = plt.subplots(1, 2, sharex='row')
ax[0].scatter(X[:, 0], X[:, 1], s=5, alpha=0.5)
gmm = GaussianMixture(n_components=4)
gmm.fit(X)
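print(gmm.weights_)  # mixing weights
print(gmm.means_)  # component means
print(gmm.covariances_)  # component covariances
print(gmm.predict_proba(X)[:10].round(5))  # posterior component probabilities for the first 10 samples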
labels = gmm.predict(X)
ax[1].scatter(X[:, 0], X[:, 1], s=5, alpha=0.5, c=labels, cmap='viridis')
ax[1].grid(ls='--')
plt.show()
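As a quick check on how the outputs of the script relate to each other, the sketch below refits the same kind of model without plotting and compares `predict` with `predict_proba`. The `random_state=0` arguments are an addition for reproducibility and are not in the script above.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

X, _ = make_blobs(n_samples=1000, centers=4, random_state=0)
gmm = GaussianMixture(n_components=4, random_state=0).fit(X)

proba = gmm.predict_proba(X)  # posterior probability of each component, shape (1000, 4)
labels = gmm.predict(X)       # hard assignment per sample

# Each row of posteriors is a probability distribution over the components.
rows_sum_to_one = np.allclose(proba.sum(axis=1), 1.0)
# The hard label should be the component with the largest posterior.
agreement = np.mean(labels == proba.argmax(axis=1))
# The fitted mixing weights form a probability distribution as well.
weights_sum_to_one = np.isclose(gmm.weights_.sum(), 1.0)

print(rows_sum_to_one, agreement, weights_sum_to_one)
```

The rows of `predict_proba` are the same posterior quantities as in the E-step of the EM updates, evaluated at the fitted parameters.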