【转】向前-向后算法(forward-bac…

学习问题

在HMM模型中,已知隐藏状态的集合S,观察值的集合O,以及一个观察序列(o1,o2,...,on),求使得该观察序列出现的可能性最大的模型参数(包括初始状态概率矩阵π,状态转移矩阵A,发射矩阵B)。这正好就是EM算法要求解的问题:已知一系列的观察值X,在隐含变量Y未知的情况下求最佳参数θ*,使得:

在中文词性标注里,根据为训练语料,我们观察到了一系列的词(对应EM中的X),如果每个词的词性(即隐藏状态)也是知道的,那它就不需要用EM来求模型参数θ了,因为Y是已知的,不存在隐含变量了。当没有隐含变量时,直接用maximum likelihood就可以把模型参数求出来。

预备知识

首先你得对下面的公式表示认同。

以下都是针对相互独立的事件,

P(A,B)=P(B|A)*P(A)

P(A,B,C)=P(C)*P(A,B|C)=P(A,C|B)*P(B)=P(B,C|A)*P(A)

P(A,B,C,D)=P(D)*P(A,B|D)*P(C|A)=P(D)*P(A,B|D)*P(C|B)

P(A,B|C)=P(D1,A,B|C)+P(D2,A,B|C)     D1,D2是事件D的一个全划分

理解了上面几个式子,你也就能理解本文中出现的公式是怎么推导出来的了。

EM算法求解

我们已经知道如果隐含变量Y是已知的,那么求解模型参数直接利用Maximum Likelihood就可以了。EM算法的基本思路是:随机初始化一组参数θ(0),根据后验概率Pr(Y|X;θ)来更新Y的期望E(Y),然后用E(Y)代替Y求出新的模型参数θ(1)。如此迭代直到θ趋于稳定。

在HMM问题中,隐含变量自然就是状态变量,要求状态变量的期望值,其实就是求时刻ti观察到xi时处于状态si的概率,为了求此概率,需要用到向前变量和向后变量。

向前变量

向前变量 是假定的参数

它表示t时刻满足状态,且t时刻之前(包括t时刻)满足给定的观测序列的概率。

  1. 令初始值
  2. 归纳法计算
  3. 最后计算
复杂度
向后变量
向后变量                
它表示在时刻t出现状态 ,且t时刻以后的观察序列满足 的概率。
  1. 初始值
  2. 归纳计算

E-Step

定义变量为t时刻处于状态i,t+1时刻处于状态j的概率。

        

定义变量表示t时刻呈现状态i的概率。

实际上       

    

 

是从其他所有状态转移到状态i的次数的期望值。

是从状态i转移出去的次数的期望值。

是从状态i转移到状态j的次数的期望值。

M-Step

是在初始时刻出现状态i的频率的期望值,
是从状态i转移到状态j的次数的期望值   除以    从状态i转移出去的次数的期望值,
是在状态j下观察到活动为k的次数的期望值    除以    从其他所有状态转移到状态j的次数的期望值,
 
然后用新的参数 再来计算向前变量、向后变量、 。如此循环迭代,直到前后两次参数的变化量小于某个值为止。
下面给出我的java代码:
View Code 

package nlp;
import java.util.ArrayList;

public class BaumWelch {

    int M; // 隐藏状态的种数
    int N; // 输出活动的种数
    double[] PI; // 初始状态概率矩阵
    double[][] A; // 状态转移矩阵
    double[][] B; // 混淆矩阵

    ArrayList observation = new ArrayList(); // 观察到的集合
    ArrayList state = new ArrayList(); // 中间状态集合
    int[] out_seq = { 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1,
, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1 }; // 测试用的观察序列
    int[] hidden_seq = { 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1,
, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1 }; // 测试用的隐藏状态序列
    int T = 32; // 序列长度为32

    double[][] alpha = new double[T][]; // 向前变量
    double PO;
    double[][] beta = new double[T][]; // 向后变量
    double[][] gamma = new double[T][];
    double[][][] xi = new double[T - 1][][];

    // 初始化参数。Baum-Welch得到的是局部最优解,所以初始参数直接影响解的好坏
    public void initParameters() {
        M = 2;
        N = 2;
        PI = new double[M];
        PI[0] = 0.5;
        PI[1] = 0.5;
        A = new double[M][];
        B = new double[M][];
        for (int i = 0; i < M; i++) {
            A[i] = new double[M];
            B[i] = new double[N];
        }
        A[0][0] = 0.8125;
        A[0][1] = 0.1875;
        A[1][0] = 0.2;
        A[1][1] = 0.8;
        B[0][0] = 0.875;
        B[0][1] = 0.125;
        B[1][0] = 0.25;
        B[1][1] = 0.75;

        observation.add(1);
        observation.add(2);
        state.add(1);
        state.add(2);

        for (int t = 0; t < T; t++) {
            alpha[t] = new double[M];
            beta[t] = new double[M];
            gamma[t] = new double[M];
        }
        for (int t = 0; t < T - 1; t++) {
            xi[t] = new double[M][];
            for (int i = 0; i < M; i++)
                xi[t][i] = new double[M];
        }
    }

    // 更新向前变量
    public void updateAlpha() {
        for (int i = 0; i < M; i++) {
            alpha[0][i] = PI[i] * B[i][observation.indexOf(out_seq[0])];
        }
        for (int t = 1; t < T; t++) {
            for (int i = 0; i < M; i++) {
                alpha[t][i] = 0;
                for (int j = 0; j < M; j++) {
                    alpha[t][i] += alpha[t - 1][j] * A[j][i];
                }
                alpha[t][i] *= B[i][observation.indexOf(out_seq[t])];
            }
        }
    }

    // 更新观察序列出现的概率,它在一些公式中当分母
    public void updatePO() {
        for (int i = 0; i < M; i++)
            PO += alpha[T - 1][i];
    }

    // 更新向后变量
    public void updateBeta() {
        for (int i = 0; i < M; i++) {
            beta[T - 1][i] = 1;
        }
        for (int t = T - 2; t >= 0; t--) {
            for (int i = 0; i < M; i++) {
                for (int j = 0; j < M; j++) {
                    beta[t][i] += A[i][j]
                            * B[j][observation.indexOf(out_seq[t + 1])]
                            * beta[t + 1][j];
                }
            }
        }
    }

    // 更新xi
    public void updateXi() {
        for (int t = 0; t < T - 1; t++) {
            double frac = 0.0;
            for (int i = 0; i < M; i++) {
                for (int j = 0; j < M; j++) {
                    frac += alpha[t][i] * A[i][j]
                            * B[j][observation.indexOf(out_seq[t + 1])]
                            * beta[t + 1][j];
                }
            }
            for (int i = 0; i < M; i++) {
                for (int j = 0; j < M; j++) {
                    xi[t][i][j] = alpha[t][i] * A[i][j]
                            * B[j][observation.indexOf(out_seq[t + 1])]
                            * beta[t + 1][j] / frac;
                }
            }
        }
    }

    // 更新gamma
    public void updateGamma() {
        for (int t = 0; t < T - 1; t++) {
            double frac = 0.0;
            for (int i = 0; i < M; i++) {
                frac += alpha[t][i] * beta[t][i];
            }
            // double frac = PO;
            for (int i = 0; i < M; i++) {
                gamma[t][i] = alpha[t][i] * beta[t][i] / frac;
            }
            // for(int i=0;i
            // gamma[t][i]=0;
            // for(int j=0;j
            // gamma[t][i]+=xi[t][i][j];
            // }
        }
    }

    // 更新状态概率矩阵
    public void updatePI() {
        for (int i = 0; i < M; i++)
            PI[i] = gamma[0][i];
    }

    // 更新状态转移矩阵
    public void updateA() {
        for (int i = 0; i < M; i++) {
            double frac = 0.0;
            for (int t = 0; t < T - 1; t++) {
                frac += gamma[t][i];
            }
            for (int j = 0; j < M; j++) {
                double dem = 0.0;
                // for (int t = 0; t < T - 1; t++) {
                // dem += xi[t][i][j];
                // for (int k = 0; k < M; k++)
                // frac += xi[t][i][k];
                // }
                for (int t = 0; t < T - 1; t++) {
                    dem += xi[t][i][j];
                }
                A[i][j] = dem / frac;
            }
        }
    }

    // 更新混淆矩阵
    public void updateB() {
        for (int i = 0; i < M; i++) {
            double frac = 0.0;
            for (int t = 0; t < T; t++)
                frac += gamma[t][i];
            for (int j = 0; j < N; j++) {
                double dem = 0.0;
                for (int t = 0; t < T; t++) {
                    if (out_seq[t] == observation.get(j))
                        dem += gamma[t][i];
                }
                B[i][j] = dem / frac;
            }
        }
    }

    // 运行Baum-Welch算法
    public void run() {
        initParameters();
        int iter = 22; // 迭代次数
        while (iter-- > 0) {
            // E-Step
            updateAlpha();
            // updatePO();
            updateBeta();
            updateGamma();
            updatePI();
            updateXi();
            // M-Step
            updateA();
            updateB();
        }
    }

    public static void main(String[] args) {
        BaumWelch bw = new BaumWelch();
        bw.run();
        System.out.println("训练后的初始状态概率矩阵:");
        for (int i = 0; i < bw.M; i++)
            System.out.print(bw.PI[i] + "\t");
        System.out.println();
        System.out.println("训练后的状态转移矩阵:");
        for (int i = 0; i < bw.M; i++) {
            for (int j = 0; j < bw.M; j++) {
                System.out.print(bw.A[i][j] + "\t");
            }
            System.out.println();
        }
        System.out.println("训练后的混淆矩阵:");
        for (int i = 0; i < bw.M; i++) {
            for (int j = 0; j < bw.N; j++) {
                System.out.print(bw.B[i][j] + "\t");
            }
            System.out.println();
        }
    }
}
迭代22次后得到的参数:
训练后的初始状态概率矩阵:
6.72801479161809E-301.0
训练后的状态转移矩阵:
0.76720211710795320.23282165928765827
0.357061195165864760.6429096688758965
训练后的混淆矩阵:
0.99589658628791480.004103413712085399
2.135019831171061E-60.9999978649801687
【原文地址】http://www.cnblogs.com/zhangchaoyang/articles/2220398.html

你可能感兴趣的:(【转】向前-向后算法(forward-bac…)