EM算法
EM 算法是Dempster,Laind,Rubin于1977年提出的求参数极大似然估计的一种方法,它可以从非完整数据集中对参数进行MLE估计,是一种非常简单实用的学习算法。这种方法可以广泛地应用于处理缺损数据、截尾数据以及带有噪声等所谓的不完全数据。具体地说,我们可以利用EM算法来填充样本中的缺失数据、发现隐藏变量的值、估计HMM中的参数、估计有限混合分布中的参数以及可以进行无监督聚类等等。
最大期望算法(Expectation Maximization Algorithm,又译为:期望最大化算法),是一种迭代算法,用于含有隐变量(hidden variable)的概率参数模型的最大似然估计或极大后验概率估计。
在统计计算中,最大期望(EM)算法是在概率(probabilistic)模型中寻找参数最大似然估计或者最大后验估计的算法,其中概率模型依赖于无法观测的隐藏变量(Latent Variable)。最大期望经常用在机器学习和计算机视觉的数据聚类(Data Clustering)领域。
最大期望算法经过两个步骤交替进行计算,第一步是计算期望(E),也就是将隐藏变量象能够观测到的一样包含在内从而计算最大似然的期望值;另外一步是最大化(M),也就是最大化在 E 步上找到的最大似然的期望值从而计算参数的最大似然估计。M 步上找到的参数然后用于另外一个 E 步计算,这个过程不断交替进行。
参数初始化
对需要估计的参数进行初始赋值,包括均值、方差、混合系数以及期望。
E-Step计算
利用概率分布公式计算后验概率,即期望。
M-step计算
重新估计参数,包括均值、方差、混合系数并且估计此参数下的期望值。
收敛性判断
将新的与旧的值进行比较,并与设置的阈值进行对比,判断迭代是否结束,若不符合条件,则返回到第2步,重新进行计算,直到收敛符合条件结束。
TEST=[[5,5],[9,1],[8,2],[4,6],[7,3]];
#投出来的结果,前面是正面向上的次数,每组结果后面数字表示反面向上的次数。
#由于每次投币要不选择A或者B,且仅从单个样本数据,无法获知,EM算法的主要目标是:通过大量计算和统计,将数据分离或寻找隐含变量。
print(TEST)#整个考虑可以从正面入手,反面配合。
def P(sA,sB,t1,t2):
#s是初始的概率,t1取正面的个数;t2取反面的个数,以概率密度为准。
PA=Cmn(t1+t2,t1)*(sA**t1)*((1-sA)**t2) #计算A出现的概率。
PB=Cmn(t1+t2,t1)*(sB**t1)*((1-sB)**t2) #计算B出现的概率。
return round(PA/(PA+PB),2) #求出新的概率值,然后根据这个概率值进行后面期望值的计算;且保留2位有效数字。
def fac(n):#阶乘。
f=1
fory in range(2,n+1):
f=f*y
return f
def Cmn(m,n):#先定义cmn后面利用这个概率密度函数,会使用排列组合关系。当Cm,n=Cm,m-n。减少计算量。
s=m-n
ifs
n=s
f=1;t=m
fory in range(0,n):
f=f*t
t=t-1
return f/fac(n)
def CoinAB(oldoA,oldoB): #计算期望值,通过一次期望值的求解,再重新迭代概率值。
UA1=0;UA2=0; #UA1和UA2是A硬币投出的期望值。
t3=0;t4=0;t5=0;t6=0; #t3和t4、t5和t6都是为计算硬币A和B的期望值。
fory in TEST: #遍历所有样本数据。
UA1,UA2=y #取当前值,前面表示正面,后面表示反面。
oA=P(oldoA,oldoB,UA1,UA2) #计算出出A硬币的概率。
print(oA,1-oA)
t3=t3+UA1*round(oA,2) #计算A期望值(针对正面这个事实开始讨论)
t4=t4+UA2*round(oA,2)
t5=t5+UA1*round((1-oA),2) #计算B期望值(针对正面这个事实开始讨论)
t6=t6+UA2*round((1-oA),2)
return round(t3/(t3+t4),2),round(t5/(t5+t6),2) #返回迭代新一轮的A、B的概率。
def EM(oA,oB):
y=0; #计迭代次数。
while(1):
y=y+1
oldoA=oA;oldoB=oB #先存储迭代数据,为了计算收敛值,结束条件。
oA,oB=CoinAB(oldoA,oldoB) #分别赋值,为了下次使用。
print("----y={},oA={},oB={}".format(y,oA,oB))
if (oldoA-oA)**2+(oldoB-oB)**2<0.005:#自己设置收敛条件,目的为终止循环。
break
print("oA=",oA,"oB=",oB)
#oA,oB=eval(input("请输入初始值-oA,oB,逗号隔开:\n"))
oA=0.6
oB=0.4
EM(oA,oB)
大家,加油!