关于EM算法的详细讲解,以及基于样本集实例的数学推导过程,详见:https://blog.csdn.net/u012421852/article/details/79915908
# -*- coding: utf-8 -*-
"""
@author: 蔚蓝的天空Tom
Aim:实现EM算法(Expectation Maximization Algorithm)
"""
import numpy as np
class CEM(object):
    """Two-coin Expectation-Maximization (EM) estimator.

    Each sample is one round of tosses (1 = heads, 0 = tails). Every round is
    scored against two candidate coins, a and b, whose heads probabilities
    ``pa`` and ``pb`` are re-estimated by alternating E-steps and M-steps
    until both estimates move by less than ``threshold`` in one iteration.
    The two coins are weighted equally when responsibilities are computed.

    The constructor runs the whole fit (it calls ``work()``), and every step
    prints its intermediate values.

    Args:
        samples: iterable of rounds, each an iterable of 0/1 outcomes.
        pa: initial guess for the heads probability of coin a.
        pb: initial guess for the heads probability of coin b.
        threshold: stop once the absolute change of both estimates in a
            single iteration is below this value.
        max_iter: upper bound on the number of EM iterations, so a fit that
            never meets ``threshold`` still terminates.

    Attributes:
        eStepRet: array with one ``[P(coin a), P(coin b)]`` row per sample,
            produced by the latest E-step.
        mStepRet: kept for interface compatibility; never assigned a result.
        converged: True when the last ``work()`` run met ``threshold``.
    """

    def __init__(self, samples, pa, pb, threshold, max_iter=1000):
        self.samples = samples
        self.pa = pa
        self.pb = pb
        self.eStepRet = None
        self.mStepRet = None
        self.threshold = threshold
        self.max_iter = max_iter
        self.converged = False
        self.work()

    @staticmethod
    def _count_heads_tails(sample):
        """Return ``(number of 1s, number of 0s)`` found in one round."""
        flips = list(sample)
        return flips.count(1), flips.count(0)

    def likelihood_func(self, samples, p):
        """Return the likelihood of every round under heads probability ``p``.

        For a round with h heads and t tails the value is
        ``p**h * (1 - p)**t``. The result is a list with one entry per round.
        """
        ret = []
        for sample in samples:
            heads, tails = self._count_heads_tails(sample)
            ret.append(np.power(p, heads) * np.power(1 - p, tails))
        return ret

    def e_step(self):
        """E-step: compute, for each round, the probability it came from a or b.

        The likelihoods of a round under ``pa`` and ``pb`` are normalised so
        that each row of ``eStepRet`` sums to 1.
        """
        likelihood_a = self.likelihood_func(self.samples, self.pa)
        likelihood_b = self.likelihood_func(self.samples, self.pb)
        # One row per round: [likelihood under a, likelihood under b].
        joint = np.column_stack((likelihood_a, likelihood_b))
        self.eStepRet = joint / joint.sum(axis=1, keepdims=True)
        print('eStepRet:\n', self.eStepRet)
        return

    def m_step(self):
        """M-step: re-estimate ``pa`` and ``pb`` from the responsibilities.

        Heads and tails of every round are split between the two coins in
        proportion to the round's row in ``eStepRet``; each new estimate is
        the coin's expected heads divided by its expected tosses.

        Returns:
            bool: True when both estimates changed by less than ``threshold``
            in absolute value.
        """
        old_pa, old_pb = self.pa, self.pb
        print('old pa:', old_pa, 'old pb:', old_pb)
        h_a, t_a = 0, 0
        h_b, t_b = 0, 0
        for sample, (w_a, w_b) in zip(self.samples, self.eStepRet):
            heads, tails = self._count_heads_tails(sample)
            h_a += heads * w_a
            t_a += tails * w_a
            h_b += heads * w_b
            t_b += tails * w_b
        self.pa = h_a / (h_a + t_a)
        self.pb = h_b / (h_b + t_b)
        print('new pa:', self.pa, 'new pb:', self.pb)
        gap_pa, gap_pb = self.pa - old_pa, self.pb - old_pb
        print('gap_pa:', gap_pa, 'gap_pb:', gap_pb)
        # Compare magnitudes: an estimate that *decreases* has a negative gap,
        # which would otherwise always look smaller than the threshold.
        return bool(abs(gap_pa) < self.threshold
                    and abs(gap_pb) < self.threshold)

    def work(self):
        """Alternate E-steps and M-steps until convergence or ``max_iter``.

        Implemented as a loop rather than recursion so the number of
        iterations is not limited by the interpreter's recursion depth.
        """
        self.converged = False
        for _ in range(self.max_iter):
            self.e_step()
            if self.m_step():
                self.converged = True
                print('stop em\n')
                return
        print('em did not converge within', self.max_iter, 'iterations\n')
        return

    def GetResult(self):
        """Return the current estimates as the tuple ``(pa, pb)``."""
        return self.pa, self.pb
def CEM_manual():
    """Run the two-coin EM demo on a fixed sample set and print the result.

    Builds five rounds of ten tosses, fits ``CEM`` starting from the initial
    guesses pa = 0.4 and pb = 0.5 with a convergence threshold of 1e-5, and
    prints the final ``(pa, pb)`` tuple. Returns nothing.
    """
    # One row per round of ten tosses (1 = heads, 0 = tails). The coin named
    # on each row is a reference label only; it is not passed to the estimator.
    samples = np.array([[1, 0, 1, 1, 1, 0, 1, 0, 1, 0],   # coin a, 6 heads / 4 tails
                        [1, 0, 1, 1, 1, 0, 1, 0, 1, 1],   # coin b, 7 heads / 3 tails
                        [1, 1, 1, 1, 1, 0, 1, 0, 1, 0],   # coin a, 7 heads / 3 tails
                        [1, 0, 1, 1, 1, 1, 1, 0, 1, 0],   # coin b, 7 heads / 3 tails
                        [1, 0, 1, 1, 1, 0, 0, 1, 0, 0]])  # coin a, 5 heads / 5 tails
    # Taking the labels at face value:
    #   p(1|a) = (6 + 7 + 5) / 30 = 0.6
    #   p(1|b) = (7 + 7) / 20 = 0.7
    # Initial guesses: p(1|a) = 0.4, p(1|b) = 0.5
    pa, pb = 0.4, 0.5
    threshold = 0.00001
    em = CEM(samples, pa, pb, threshold)
    ret = em.GetResult()
    print(ret)
    return
# Run the demo only when this file is executed directly as a script.
if __name__ == "__main__":
    CEM_manual()
runfile('C:/Users/l13277/EM.py', wdir='C:/Users/l13277')
eStepRet:
[[ 0.35215613 0.64784387]
[ 0.26599464 0.73400536]
[ 0.26599464 0.73400536]
[ 0.26599464 0.73400536]
[ 0.44914893 0.55085107]]
old pa: 0.4 old pb: 0.5
new pa: 0.621811879589 new pb: 0.648553523107
gap_pa: 0.221811879589 gap_pb: 0.148553523107
eStepRet:
[[ 0.51017251 0.48982749]
[ 0.4813223 0.5186777 ]
[ 0.4813223 0.5186777 ]
[ 0.4813223 0.5186777 ]
[ 0.53895512 0.46104488]]
old pa: 0.621811879589 old pb: 0.648553523107
new pa: 0.63630074059 new pb: 0.643678879642
gap_pa: 0.0144888610009 gap_pb: -0.00487464346478
eStepRet:
[[ 0.50320194 0.49679806]
[ 0.49519623 0.50480377]
[ 0.49519623 0.50480377]
[ 0.49519623 0.50480377]
[ 0.51120602 0.48879398]]
old pa: 0.63630074059 old pb: 0.643678879642
new pa: 0.638975359185 new pb: 0.64102463807
gap_pa: 0.00267461859504 gap_pb: -0.00265424157212
eStepRet:
[[ 0.50088945 0.49911055]
[ 0.49866584 0.50133416]
[ 0.49866584 0.50133416]
[ 0.49866584 0.50133416]
[ 0.50311303 0.49688697]]
old pa: 0.638975359185 old pb: 0.64102463807
new pa: 0.639715379851 new pb: 0.640284620153
gap_pa: 0.000740020666399 gap_pb: -0.00074001791737
eStepRet:
[[ 0.50024707 0.49975293]
[ 0.4996294 0.5003706 ]
[ 0.4996294 0.5003706 ]
[ 0.4996294 0.5003706 ]
[ 0.50086473 0.49913527]]
old pa: 0.639715379851 old pb: 0.640284620153
new pa: 0.639920938888 new pb: 0.640079061112
gap_pa: 0.000205559037088 gap_pb: -0.000205559040588
eStepRet:
[[ 0.50006863 0.49993137]
[ 0.49989706 0.50010294]
[ 0.49989706 0.50010294]
[ 0.49989706 0.50010294]
[ 0.5002402 0.4997598 ]]
old pa: 0.639920938888 old pb: 0.640079061112
new pa: 0.639978038581 new pb: 0.640021961419
gap_pa: 5.7099692811e-05 gap_pb: -5.70996928321e-05
eStepRet:
[[ 0.50001906 0.49998094]
[ 0.4999714 0.5000286 ]
[ 0.4999714 0.5000286 ]
[ 0.4999714 0.5000286 ]
[ 0.50006672 0.49993328]]
old pa: 0.639978038581 old pb: 0.640021961419
new pa: 0.639993899606 new pb: 0.640006100394
gap_pa: 1.58610249222e-05 gap_pb: -1.58610249225e-05
eStepRet:
[[ 0.5000053 0.4999947 ]
[ 0.49999206 0.50000794]
[ 0.49999206 0.50000794]
[ 0.49999206 0.50000794]
[ 0.50001853 0.49998147]]
old pa: 0.639993899606 old pb: 0.640006100394
new pa: 0.639998305446 new pb: 0.640001694554
gap_pa: 4.40584023775e-06 gap_pb: -4.40584023775e-06
stop em
(0.63999830544606295, 0.64000169455393696)
(end)