Statistical Learning Methods (统计学习方法), Chapter 9: The EM Algorithm

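The EM algorithm computes maximum likelihood estimates of the parameters of probabilistic models that contain latent (hidden) variables. Each iteration consists of two steps: the E-step computes an expectation, and the M-step maximizes it. EM is also an effective method for fitting Gaussian mixture models (GMMs).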
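EM for a Gaussian mixture follows the same E/M pattern. The sketch below is a minimal version that assumes one-dimensional data and a quantile-based initialization. Both choices were made here for illustration and are not taken from the book. The helper `gmm_em_1d` is hypothetical. Its E-step computes each point's responsibility under each component. Its M-step then updates the mixing weights, means and variances in closed form.

```python
import numpy as np

def gmm_em_1d(x, k=2, n_iter=100, tol=1e-8):
    """Illustrative EM for a 1-D Gaussian mixture with k components."""
    x = np.asarray(x, dtype=float)
    # Initialization (an assumption): equal weights, means at spread-out
    # quantiles of the data, every variance equal to the overall variance
    alpha = np.full(k, 1.0 / k)
    mu = np.quantile(x, (np.arange(k) + 0.5) / k)
    var = np.full(k, x.var())
    prev_ll = -np.inf
    for _ in range(n_iter):
        # E-step: weighted component densities, shape (n, k)
        dens = alpha * np.exp(-(x[:, None] - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
        total = dens.sum(axis=1, keepdims=True)
        gamma = dens / total              # responsibilities
        ll = np.log(total).sum()          # log-likelihood of the current parameters
        # M-step: closed-form updates of weights, means and variances
        nk = gamma.sum(axis=0)
        alpha = nk / x.size
        mu = (gamma * x[:, None]).sum(axis=0) / nk
        var = (gamma * (x[:, None] - mu) ** 2).sum(axis=0) / nk
        # Stop once the log-likelihood no longer improves noticeably
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    return alpha, mu, var
```

This sketch has no safeguard against a component's variance collapsing toward zero. A practical implementation would add one, for example a small variance floor.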
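Below is a simple implementation of EM, using Exercise 9.1 as the example. The exercise revisits the three-coin model of Example 9.1. Coin A lands heads with probability π. If A shows heads, coin B (heads probability p) is tossed; otherwise coin C (heads probability q) is tossed. Only the outcome of this second toss is recorded (1 for heads, 0 for tails). The task is to estimate (π, p, q) from the observations, starting from the initial values π = 0.46, p = 0.55, q = 0.67.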
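In the three-coin model, π, p and q are the heads probabilities of coins A, B and C, and y_1, …, y_n are the observations. Let μ_j be the posterior probability that y_j came from coin B; this is the array `u` in the code. One EM iteration then reduces to the closed-form updates below. The first line is the E-step; the second is the M-step.

```latex
\mu_j^{(i+1)} = \frac{\pi^{(i)} \bigl(p^{(i)}\bigr)^{y_j} \bigl(1-p^{(i)}\bigr)^{1-y_j}}
{\pi^{(i)} \bigl(p^{(i)}\bigr)^{y_j} \bigl(1-p^{(i)}\bigr)^{1-y_j} + \bigl(1-\pi^{(i)}\bigr) \bigl(q^{(i)}\bigr)^{y_j} \bigl(1-q^{(i)}\bigr)^{1-y_j}}

\pi^{(i+1)} = \frac{1}{n}\sum_{j=1}^{n} \mu_j^{(i+1)}, \qquad
p^{(i+1)} = \frac{\sum_{j=1}^{n} \mu_j^{(i+1)} y_j}{\sum_{j=1}^{n} \mu_j^{(i+1)}}, \qquad
q^{(i+1)} = \frac{\sum_{j=1}^{n} \bigl(1-\mu_j^{(i+1)}\bigr) y_j}{\sum_{j=1}^{n} \bigl(1-\mu_j^{(i+1)}\bigr)}
```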

import numpy as np

# E-step: posterior probability u_j that observation y_j came from coin B
def expectationCal(pi, p, q, y):
    u = pi * (p ** y) * ((1 - p) ** (1 - y))
    v = (1 - pi) * (q ** y) * ((1 - q) ** (1 - y))
    return u / (u + v)

# M-step: closed-form re-estimates of pi, p, q given the posteriors u
def maximumCal(u, y):
    n = y.shape[0]
    uSum = np.sum(u)
    pi = uSum / n
    p = np.sum(u * y) / uSum
    q = np.sum((1 - u) * y) / np.sum(1 - u)
    return pi, p, q

# Parameter estimation: alternate E and M steps. The parameters are kept as
# scalars; NumPy broadcasts them over the observation array.
def EMCal(yLabel, pi, p, q, iter=40, tol=1e-4):
    y = np.array(yLabel, dtype=float)
    for _ in range(iter):
        u = expectationCal(pi, p, q, y)
        pi_new, p_new, q_new = maximumCal(u, y)
        # Stop only once all three parameters have stopped changing
        converged = (abs(pi_new - pi) <= tol and abs(p_new - p) <= tol
                     and abs(q_new - q) <= tol)
        pi, p, q = pi_new, p_new, q_new
        if converged:
            break
    return pi, p, q
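One property of EM is that the observed-data log-likelihood never decreases from one iteration to the next. The sketch below checks this for the three-coin model. It uses two hypothetical helpers, `log_likelihood` and `em_step`, which are written independently of the code above. For this model the likelihood depends only on the marginal heads probability πp + (1 − π)q. After any M-step, πp = (1/n)Σμ_j y_j and (1 − π)q = (1/n)Σ(1 − μ_j)y_j, so their sum equals the sample mean 0.6. That is already the Bernoulli maximum likelihood value, so the parameters stop moving after the first iteration.

```python
import numpy as np

def log_likelihood(pi, p, q, y):
    # Marginal probability of heads: P(y = 1) = pi*p + (1 - pi)*q
    h = pi * p + (1 - pi) * q
    return float(np.sum(y * np.log(h) + (1 - y) * np.log(1 - h)))

def em_step(pi, p, q, y):
    # One E-step followed by one M-step for the three-coin model
    a = pi * p ** y * (1 - p) ** (1 - y)
    b = (1 - pi) * q ** y * (1 - q) ** (1 - y)
    u = a / (a + b)
    return u.mean(), np.sum(u * y) / u.sum(), np.sum((1 - u) * y) / np.sum(1 - u)

y = np.array([1, 1, 0, 1, 0, 0, 1, 0, 1, 1], dtype=float)
params = (0.46, 0.55, 0.67)
history = [log_likelihood(*params, y)]
for _ in range(5):
    params = em_step(*params, y)
    history.append(log_likelihood(*params, y))
print([round(v, 6) for v in history])
```

Every value after the first equals 6 ln 0.6 + 4 ln 0.4 ≈ −6.7301, up to floating-point error.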

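Run it with the data and initial values from the exercise:

yLabel = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]
pi = 0.46
p = 0.55
q = 0.67
pi_new, p_new, q_new = EMCal(yLabel, pi, p, q)
print(round(pi_new, 4), round(p_new, 4), round(q_new, 4))

This gives pi_new = 0.4619, p_new = 0.5346, q_new = 0.6561.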
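The estimates EM returns depend on the initial values. The sketch below runs the three-coin EM from several starting points. It uses a hypothetical helper, `em_three_coin`, which repeats the E-step and M-step until no parameter changes by more than `tol`.

```python
import numpy as np

def em_three_coin(y, pi, p, q, max_iter=100, tol=1e-10):
    y = np.asarray(y, dtype=float)
    for _ in range(max_iter):
        # E-step: posterior probability that each observation came from coin B
        a = pi * p ** y * (1 - p) ** (1 - y)
        b = (1 - pi) * q ** y * (1 - q) ** (1 - y)
        u = a / (a + b)
        # M-step: closed-form updates
        new = (u.mean(), np.sum(u * y) / u.sum(), np.sum((1 - u) * y) / np.sum(1 - u))
        done = max(abs(nv - ov) for nv, ov in zip(new, (pi, p, q))) < tol
        pi, p, q = new
        if done:
            break
    return pi, p, q

y = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]
for init in [(0.5, 0.5, 0.5), (0.4, 0.6, 0.7), (0.46, 0.55, 0.67)]:
    est = em_three_coin(y, *init)
    print(init, "->", tuple(round(float(v), 4) for v in est))
```

The three starting points give three different estimates:

- (0.5, 0.6, 0.6)
- (0.4064, 0.5368, 0.6432)
- (0.4619, 0.5346, 0.6561)

All three satisfy πp + (1 − π)q = 0.6. This model is not identifiable from the observations alone, so EM stops at whichever point on that curve the initial values lead to.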
