机器学习入门算法及其java实现-EM(Expectation Maxium)算法

1、算法基本原理:

  • EM算法一般用于存在隐变量或潜在变量的概率模型,可以算是一种含有隐的概率模型参数的极大似然估计法;
  • 假设 θ 为模型的参数, 为模型的观测数据, γ 模型中存在的隐藏变量,EM算法的是通过最大化观测数据 logP(Y|θ) 的方法来求出 θ 的极大似然估计,可以转化为表达式: θ^=argmaxθ(logP(Y|θ))
  • 经过转化,可以将问题转化为最大化 E(γ) 的问题,即 θ^=argmaxγ(E(γ))

2、算法推导过程:

  • 根据极大似然法的原理,我们的目标是极大化观测数据 Y 关于参数 θ 的对数似然函数,即:
    L(θ)=logP(Y|θ)=logγP(Y,γ|θ)
    =log(λP(Y|γ,θ)P(Z|θ))
  • 因为EM算法是通过迭代的办法逐步接近极大 L(θ) 的,假设在第 i 次迭代后 θi ,此我们希望能够使 L(θ)L(θ(i))0
    L(θ)L(θi)=log(γP(Y|γ,θ)P(γ|θ))log(P(Y|θi)
    =log(γP(γ|Y,θi)P(Y|γ,θ)P(γ|θ)P(γ|Y,θi))logP(Y|θi)
    γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi))logP(Y|θi)
    =γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi))
    B(θ,θi)=L(θi)+γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi))
    L(θ)B(θ,θi)
    即函数 B(θ,θi) L(θ) 的一个下界,并且易知: L(θi)B(θi,θi) ,因此可以使 B(θ,θi) 增大的 θ 也可以使 L(θ) 增大,为了使 L(θ) 有尽可能大的增大,选择 θi+1 使 B(θ,θi) 打到极大,即:
    θ(i+1)=argmaxθB(θ,θi)
    上式可以改写为:
    θ(i+1)=argmaxθ(L(θi)+γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi)))
    =argmaxθγP(γ|Y),θilog(P(Y|γ,θ)P(γ|θ))
    =argmaxθγP(γ|Y,θi)log(P(Y,γ|θ))
    =argmaxθQ(θ,θi)
    3、EM算法收敛性证明: 根据对数函数函数性质:若 P(Y|θi) 单调递增且收敛到某一值则 Q(θ,θi) 收敛。 单调性:
    P(Y|θ)=P(Y,γ|θ)P(γ|Y,θ)
    logP(Y|θ)=logP(Y,γ|θ)logP(γ|Y,θ)
    Q(θ,θi)=γlogP(Y,γ|θ)P(γ|Y,θi)
    H(θ,θi)=γlogP(γ|Y,θ)P(γ|Y,θi)
    于是对数似然函数可以写成:
    logP(Y|θ)=Q(θ,θi)H(θ,θi)
    上式中 θ 分别取为 θi θi+1 并相减,有:
    logP(Y|θi+1)logP(Y|θi)
    =[Q(θi+1,θi)Q(θi,θi)][H(θi+1,θi)H(θi,θi)]
    因为 θi+1 使Q(\theta,\theta^{i})达到极大,所以有:
    Q(θi+1,θi)Q(θi,θi)0
    其第2项,可以推导得出:
    H(θi+1,θi)H(θi,θi)
    =γ(logp(γ|Y,θi+1)P(γ|Y,θi))P(γ|Y,θi)
    log(γP(γ|Y,θi+1)P(γ|Y,θi)P(γ|Y,θi))
    =log(P(γ|Y,θi+1))=0
    又因为 P(Y|θi) 有界,所以 L(θi)=log(P(Y|θi)) 收敛到某一值 L

4、算法步骤:

  • 选择参数的初值 θ0 ,开始迭代;
  • E步:记 θi 为第 i 次迭代参数 θ 的估计值,在第 i 次迭代的E步,计算:
    Q(θ,θi)=Eγ[logP(Y,γ|θ)|Y,θ]
    =γlog(P(Y,γ|θ)P(γ|Y,θi))
  • M步:求使 Q(θ,θi) 极大化的 θ ,确定第 i+1 次迭代的参数的估计值 θi+1
    θi+1=argmaxθQ(θ,θi)

    -重复第E步和第M步,直到对于较小的正数 ξ1 ξ2 ,若满足 :
    ||θi+1θi||ξq
    ||Q(θi+1,θi)Q(θi,θi)||ξ2
    则停止迭代。
package binorandom;

public class binomain {

    public static void main(String[] args) {
        int[] b=new int[1000];
        for (int i=0;i<1000;i++){
        b[i]=binorandom.getBinomial(1, 0.4);
        }
        int[] a=new int[1000];
        for ( int i=0;i<999;i++){
            if (b[i]==1){
                a[i]=binorandom.getBinomial(1,0.5);
            }
            if(b[i]==0){
                a[i]=binorandom.getBinomial(1,0.6);
            }
            System.out.print(a[i]+" ");
        }
        System.out.print(a[999]);
    }

}


package binorandom;

public class binorandom {
    public static int getBinomial(int n, double p) {
         int x = 0;
         for(int i = 0; i < n; i++) {
         if(Math.random() < p)
          x++;
         }
         return x;
        }
}


//生成数据集合

package EMpackage;
import java.util.Scanner;
public class EMmain {
    public static void main(String[] args){
        System.out.println("请输入观测值个数");
        Scanner input=new Scanner(System.in);
        int datanumber=input.nextInt();
        System.out.println("请输入观测值(0或者1):");
        Scanner input1=new Scanner(System.in);
        int[] obdata=new int[datanumber];
        for(int i=0; iout.println("您输入的是:"+" ");
        for (int b=0;b1;b++){
            System.out.print(obdata[b]+" ");
        }
        System.out.println(obdata[datanumber-1]);
        double[] original=new double[3];
        original=ori.original();
        double eq=ori.eq();
        System.out.println("初始条件为:"+" "+original[0]+" "+original[1]+" "+original[2]);
        System.out.println("停止条件为:"+" "+eq);
        input1.close();
        input.close();
        double[] original1=new double[3];
        original1=EM.original1(original, obdata, datanumber);   
        int x=0;
        while (euclid(minus(original1,original))>eq){
        original=original1;
        original1=EM.original1(original,obdata,datanumber);
        x=x+1;
        }
        System.out.println("pi="+original1[0]+"\n"+"p="+original1[1]+"\n"+"q="+original1[2]+"\n"+x);
    }

private static double euclid(double[] x) {
    double sum=0;
    for (int i=0;i<3;i++){
        sum=sum+Math.pow(x[i], 2);
    }
    double euclid=Math.sqrt(sum);
    return euclid;
}

private static double[] minus(double[] x,double[] y) {
    double[] temp=new double[3];
    for (int i=0;i<3;i++){
        temp[i]=x[i]-y[i];
    }
    return temp;
  }
}


package EMpackage;
public class EM {
    public static double[] original1(double[] original,int[] obdata,int datanumber){
        double[] ybl=new double[datanumber];
        double[] uybl=new double[datanumber];
        double[] l=new double[datanumber];
        double datanumber1=datanumber;
        for (int i=0;i0]*Math.pow(original[1],obdata[i] )*Math.pow(1-original[1],1-obdata[i] ))/(original[0]*Math.pow(original[1],obdata[i])*Math.pow((1-original[1]),(1-obdata[i]))+(1-original[0])*Math.pow(original[2],obdata[i])*Math.pow((1-original[2]),(1-obdata[i])));
            uybl[i]=obdata[i]*(original[0]*Math.pow(original[1],obdata[i] )*Math.pow(1-original[1],1-obdata[i] ))/(original[0]*Math.pow(original[1],obdata[i])*Math.pow((1-original[1]),(1-obdata[i]))+(1-original[0])*Math.pow(original[2],obdata[i])*Math.pow((1-original[2]),(1-obdata[i])));
            l[i]=1;
        }
        double[] original1=new double[3];
        original1[0]=(1/datanumber1)*sum(ybl,datanumber);
        original1[1]=(sum(uybl,datanumber)/sum(ybl,datanumber));
        original1[2]=(sum(ybl,datanumber)-sum(uybl,datanumber))/(sum(l,datanumber)-sum(ybl,datanumber));
        return original1;   
    }

    private static double sum(double[] ybl,int datanumber) {
        double sum=0;
        for (int i=0;ireturn sum;
    }
}


package EMpackage;

import java.util.Scanner;

public class ori{ 
    public static double[] original(){
        System.out.println("请输入初始条件条件:"+" ");

    Scanner input=new Scanner(System.in);
    double original[]=new double[3];
    for(int d=0; d<3;d++){ 
        original[d]=input.nextDouble();
        }
    return original;
    }
    public static double eq(){
        System.out.println("请输入停止条件:"+" ");
        Scanner input=new Scanner(System.in);
        double eq=input.nextDouble();
        return eq;
    }
 }
//EM算法主程序

实验结果及实例分析
机器学习入门算法及其java实现-EM(Expectation Maxium)算法_第1张图片
多次运算结果对比:
原始系数pi,p,q(0.4、0.5、0.6):

初始迭代系数 (0.5、0.5、0.5) (0.4、0.4、0.4) (0.4、0.4、0.5) (0.5、0.4、0.6) (0.4、0.5、0.4) (0.5、0.4、0.5) (0.5、0.6、0.4)
运算结果 (0.5、0.73、0.32) (0.54、0.84、0.19) (0.55、0.84、0.19) (0.56、0.77、0.29) (0.56、0.85、0.19) (0.56、0.77、0.30) (0.56、0.76、0.30)

原始系数pi,p,q(0.5、0.5、0.5):

初始迭代系数 (0.4、0.4、0.4) (0.3、0.4、0.4) (0.4、0.4、0.5) (0.4、0.4、0.6) (0.4、0.5、0.4) (0.5、0.4、0.3) (0.5、0.6、0.4)
运算结果 (0.49、0.7、0.23) (0.49、0.85、0.14) (0.5、0.76、0.23) (0.49、0.75、0.24) (0.49、0.76、0.23) (0.49、1.02、-0.02) (0.5、0.53、0.45)

从以上两表不难看出EM算法受到初始迭代值的影响十分大,但是其优点在于需要的迭代次数少,收敛速度十分迅速。

你可能感兴趣的:(机器学习十大算法,分类算法)