5.【笔记】统计学习方法—EM算法

文章目录

  • 1.EM是什么
    • 1.1 E:求期望
    • 1.2 M:极大
  • 2. 代码

1.EM是什么

EM是含有隐变量的概率模型的极大似然估计回极大后验概率估计的迭代算法。

假设:
P ( Y ∣ θ ) = ∏ [ π p y i ( 1 − p ) 1 − y i + ( 1 − π ) q y i ( 1 − q ) 1 − y i ] P(Y|\theta) = \prod[\pi p^{y_i}(1-p)^{1-y_i}+(1-\pi) q^{y_i}(1-q)^{1-y_i}] P(Yθ)=[πpyi(1p)1yi+(1π)qyi(1q)1yi]

1.1 E:求期望

μ i + 1 = π ( p i ) y i ( 1 − ( p i ) ) 1 − y i π ( p i ) y i ( 1 − ( p i ) ) 1 − y i + ( 1 − π ) ( q i ) y i ( 1 − ( q i ) ) 1 − y i \mu^{i+1}=\frac{\pi (p^i)^{y_i}(1-(p^i))^{1-y_i}}{\pi (p^i)^{y_i}(1-(p^i))^{1-y_i}+(1-\pi) (q^i)^{y_i}(1-(q^i))^{1-y_i}} μi+1=π(pi)yi(1(pi))1yi+(1π)(qi)yi(1(qi))1yiπ(pi)yi(1(pi))1yi

1.2 M:极大

π i + 1 = 1 n ∑ j = 1 n μ j i + 1 \pi^{i+1}=\frac{1}{n}\sum_{j=1}^n\mu^{i+1}_j πi+1=n1j=1nμji+1

p i + 1 = ∑ j = 1 n μ j i + 1 y i ∑ j = 1 n μ j i + 1 p^{i+1}=\frac{\sum_{j=1}^n\mu^{i+1}_jy_i}{\sum_{j=1}^n\mu^{i+1}_j} pi+1=j=1nμji+1j=1nμji+1yi

q i + 1 = ∑ j = 1 n ( 1 − μ j i + 1 y i ) ∑ j = 1 n ( 1 − μ j i + 1 ) q^{i+1}=\frac{\sum_{j=1}^n(1-\mu^{i+1}_jy_i)}{\sum_{j=1}^n(1-\mu^{i+1}_j)} qi+1=j=1n(1μji+1)j=1n(1μji+1yi)

2. 代码

#py3.7
class EM:
    def __init__(self, prob):
        self.pro_A, self.pro_B, self.pro_C = prob
        
    # e_step
    def pmf(self, i):
        pro_1 = self.pro_A * math.pow(self.pro_B, data[i]) * math.pow((1-self.pro_B), 1-data[i])
        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, data[i]) * math.pow((1-self.pro_C), 1-data[i])
        return pro_1 / (pro_1 + pro_2)
    
    # m_step
    def fit(self, data):
        count = len(data)
        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B, self.pro_C))
        for d in range(count):
            _ = yield
            _pmf = [self.pmf(k) for k in range(count)]
            pro_A = 1/ count * sum(_pmf)
            pro_B = sum([_pmf[k]*data[k] for k in range(count)]) / sum([_pmf[k] for k in range(count)])
            pro_C = sum([(1-_pmf[k])*data[k] for k in range(count)]) / sum([(1-_pmf[k]) for k in range(count)])
            print('{}/{}  pro_a:{:.3f}, pro_b:{:.3f}, pro_c:{:.3f}'.format(d+1, count, pro_A, pro_B, pro_C))
            self.pro_A = pro_A
            self.pro_B = pro_B
            self.pro_C = pro_C

你可能感兴趣的:(统计学习方法)