HMM EM & Viterbi

本文只是为了动手实现这几个算法(前向-后向/EM 训练与 Viterbi 解码),没有用到什么特别的技巧。
package tseg;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Random;

/**
 * Discrete hidden Markov model offering forward/backward evaluation,
 * Baum-Welch (EM) re-estimation of the transition and emission matrices, and
 * Viterbi decoding.
 *
 * <p>
 * States and observation symbols are identified by zero-based integers. All
 * computations multiply raw probabilities; no scaling or log-space arithmetic
 * is applied. Instances keep mutable working buffers and perform no
 * synchronization.
 *
 * @author yinxs
 * @version created 2011-12-27 13:54:02
 */
public class Hmm {
	/** EM stops once the summed absolute parameter change drops to this value. */
	private static final double CONVERGENCE_THRESHOLD = 0.001;
	/** Floor applied to every re-estimated probability. */
	private static final double MIN_PROB = 0.000001;
	/** Model file used by the no-argument {@link #InitFile()}. */
	private static final String DEFAULT_MODEL_PATH = "D:\\workspace_weibo\\hmm\\src\\tseg\\model.hmm";
	/** Observation file used by the no-argument {@link #ReadObservation()}. */
	private static final String DEFAULT_OBSERVATION_PATH = "D:\\workspace_weibo\\hmm\\src\\tseg\\test.hmm";
	/** Dimension used by the built-in {@link #setObserArray()} test fixture. */
	private static final int TEST_FIXTURE_SIZE = 3;

	private int max_states;
	private int max_symbols;
	private int num_observations;
	private int[] obserArray;
	/** transMatrix[i][j]: probability of moving from state i to state j. */
	private double[][] transMatrix;
	/** emissMatrix[s][k]: probability of state s emitting symbol k. */
	private double[][] emissMatrix;
	/** Initial state distribution. */
	private double[] pi;
	/** Working buffers, all indexed [state][time] (xi: [from][to][time]). */
	private double[][] forward;
	private double[][] backward;
	private double[][] gamma;
	private double[][][] xi;
	private double[][] viterbi;
	/** backTrace[s][t]: best predecessor of state s at time t. */
	private int[][] backTrace;
	/** State sequence found by the most recent {@link #Viterbi()} call. */
	private int[] bestPath;

	/**
	 * Creates an empty model. Load parameters with {@link #InitFile()} and
	 * observations with {@link #ReadObservation()} before evaluating.
	 */
	public Hmm() {
	}

	/**
	 * Resets the model to uniform distributions (the maximum-entropy starting
	 * point chosen by the original author), allocating the parameter arrays
	 * for the current dimensions. Despite the name, no random numbers are
	 * involved.
	 */
	private void RandInitMatrix() {
		transMatrix = new double[max_states][max_states];
		emissMatrix = new double[max_states][max_symbols];
		pi = new double[max_states];
		for (int i = 0; i < max_states; ++i) {
			for (int j = 0; j < max_states; ++j) {
				transMatrix[i][j] = 1.0 / max_states;
			}
			for (int k = 0; k < max_symbols; ++k) {
				emissMatrix[i][k] = 1.0 / max_symbols;
			}
			// The initial distribution ranges over states, not symbols.
			pi[i] = 1.0 / max_states;
		}
	}

	/**
	 * Test fixture: installs the observation sequence {@code 0, 1, 2}. When no
	 * model dimensions have been set yet, a 3-state / 3-symbol model is
	 * assumed so that {@link #ForwardBackward()} can run directly afterwards.
	 *
	 * @throws IllegalStateException
	 *             if a model with fewer than three symbols is already loaded
	 */
	public void setObserArray() {
		if (max_states == 0 && max_symbols == 0) {
			max_states = TEST_FIXTURE_SIZE;
			max_symbols = TEST_FIXTURE_SIZE;
		} else if (max_symbols < TEST_FIXTURE_SIZE) {
			throw new IllegalStateException(
					"test observations need at least " + TEST_FIXTURE_SIZE
							+ " symbols, model has " + max_symbols);
		}
		num_observations = TEST_FIXTURE_SIZE;
		obserArray = new int[num_observations];
		for (int t = 0; t < num_observations; ++t) {
			obserArray[t] = t;
		}
		allocateWorkBuffers();
	}

	/**
	 * Loads the whole model from the default, hard-coded model file.
	 *
	 * @throws IOException
	 *             if the file cannot be read or is malformed
	 */
	public void InitFile() throws IOException {
		InitFile(DEFAULT_MODEL_PATH);
	}

	/**
	 * Loads the whole model from a text file laid out as: number of states,
	 * number of symbols, the initial distribution, one transition row per
	 * state, one emission row per state. Values on a line are separated by
	 * whitespace. The current model is only replaced after the file has been
	 * parsed completely.
	 *
	 * @param path
	 *            location of the model file
	 * @throws IOException
	 *             if the file cannot be read, ends early, or a line holds too
	 *             few values
	 */
	public void InitFile(String path) throws IOException {
		int states;
		int symbols;
		double[] initial;
		double[][] trans;
		double[][] emiss;

		BufferedReader br = new BufferedReader(new FileReader(path));
		try {
			states = Integer.parseInt(readTokens(br, 1, "state count")[0]);
			symbols = Integer.parseInt(readTokens(br, 1, "symbol count")[0]);
			if (states <= 0 || symbols <= 0) {
				throw new IOException("state and symbol counts must be positive: "
						+ states + ", " + symbols);
			}
			initial = parseRow(readTokens(br, states, "initial distribution"), states);
			trans = new double[states][];
			for (int i = 0; i < states; ++i) {
				trans[i] = parseRow(readTokens(br, states, "transition row " + i), states);
			}
			emiss = new double[states][];
			for (int i = 0; i < states; ++i) {
				emiss[i] = parseRow(readTokens(br, symbols, "emission row " + i), symbols);
			}
		} finally {
			// Closing the BufferedReader also closes the wrapped FileReader.
			br.close();
		}

		max_states = states;
		max_symbols = symbols;
		pi = initial;
		transMatrix = trans;
		emissMatrix = emiss;
		if (obserArray != null) {
			// Buffers are sized by state count, so they must follow the new model.
			allocateWorkBuffers();
		}
	}

	/**
	 * Writes the model parameters to a file. Not implemented yet; currently a
	 * no-op.
	 *
	 * @throws IOException
	 *             declared for the future implementation
	 */
	public void DumpFile() throws IOException {

	}

	/**
	 * Reads the observation sequence from the default, hard-coded file.
	 *
	 * @throws IOException
	 *             if the file cannot be read or is malformed
	 */
	public void ReadObservation() throws IOException {
		ReadObservation(DEFAULT_OBSERVATION_PATH);
	}

	/**
	 * Reads the observation sequence from the first line of a text file
	 * (whitespace-separated symbol indices) and allocates the working buffers
	 * for it.
	 *
	 * @param path
	 *            location of the observation file
	 * @throws IOException
	 *             if the file cannot be read, has no observations, or holds a
	 *             symbol outside the loaded model's symbol range
	 */
	public void ReadObservation(String path) throws IOException {
		String line;
		BufferedReader br = new BufferedReader(new FileReader(path));
		try {
			line = br.readLine();
		} finally {
			br.close();
		}
		if (line == null || line.trim().length() == 0) {
			throw new IOException("no observations found in " + path);
		}

		// Split once instead of re-splitting the line for every element.
		String[] tokens = line.trim().split("\\s+");
		int[] observations = new int[tokens.length];
		for (int t = 0; t < tokens.length; ++t) {
			int symbol = Integer.parseInt(tokens[t]);
			if (symbol < 0 || (max_symbols > 0 && symbol >= max_symbols)) {
				throw new IOException("observation " + t + " has symbol " + symbol
						+ " outside the valid range");
			}
			observations[t] = symbol;
		}

		num_observations = observations.length;
		obserArray = observations;
		allocateWorkBuffers();
	}

	/**
	 * Forward algorithm: fills the forward table and returns the probability
	 * of the observation sequence under the current model.
	 *
	 * @return probability of the loaded observations
	 * @throws IllegalStateException
	 *             if the model or the observations have not been set
	 */
	public double Forward() {
		requireObservations();
		requireModel();
		// Initialisation
		for (int s = 0; s < max_states; ++s) {
			forward[s][0] = pi[s] * emissMatrix[s][obserArray[0]];
		}
		// Recursion
		for (int t = 1; t < num_observations; ++t) {
			for (int s = 0; s < max_states; ++s) {
				double incoming = 0.0;
				for (int ps = 0; ps < max_states; ++ps) {
					incoming += forward[ps][t - 1] * transMatrix[ps][s];
				}
				forward[s][t] = incoming * emissMatrix[s][obserArray[t]];
			}
		}
		// Termination
		double likelihood = 0.0;
		for (int s = 0; s < max_states; ++s) {
			likelihood += forward[s][num_observations - 1];
		}
		return likelihood;
	}

	/**
	 * Backward algorithm: fills the backward table and returns the sequence
	 * probability computed from it.
	 *
	 * @return probability of the loaded observations
	 */
	private double Backward() {
		requireObservations();
		requireModel();
		// Initialisation
		for (int s = 0; s < max_states; ++s) {
			backward[s][num_observations - 1] = 1.0;
		}
		// Recursion
		for (int t = num_observations - 2; t >= 0; --t) {
			for (int s = 0; s < max_states; ++s) {
				double outgoing = 0.0;
				for (int ns = 0; ns < max_states; ++ns) {
					outgoing += backward[ns][t + 1] * transMatrix[s][ns]
							* emissMatrix[ns][obserArray[t + 1]];
				}
				backward[s][t] = outgoing;
			}
		}
		// Termination
		double likelihood = 0.0;
		for (int s = 0; s < max_states; ++s) {
			likelihood += backward[s][0] * pi[s] * emissMatrix[s][obserArray[0]];
		}
		return likelihood;
	}

	/**
	 * E-step: runs forward and backward once and derives both posterior tables
	 * from them.
	 *
	 * @throws IllegalStateException
	 *             if the observations have zero probability under the model
	 */
	private void computePosteriors() {
		double likelihood = Forward();
		Backward();
		if (!(likelihood > 0)) {
			throw new IllegalStateException(
					"observation sequence has no positive probability under the current model");
		}
		Gamma(likelihood);
		Xi(likelihood);
	}

	/** Fills gamma[s][t], the posterior probability of state s at time t. */
	private void Gamma(double likelihood) {
		for (int s = 0; s < max_states; ++s) {
			for (int t = 0; t < num_observations; ++t) {
				gamma[s][t] = forward[s][t] * backward[s][t] / likelihood;
			}
		}
	}

	/**
	 * Fills xi[i][j][t], the posterior probability of the transition i to j
	 * between times t and t+1. The last time slot is never written and stays 0.
	 */
	private void Xi(double likelihood) {
		for (int i = 0; i < max_states; ++i) {
			for (int j = 0; j < max_states; ++j) {
				for (int t = 0; t < num_observations - 1; ++t) {
					xi[i][j][t] = forward[i][t] * backward[j][t + 1]
							* transMatrix[i][j]
							* emissMatrix[j][obserArray[t + 1]] / likelihood;
				}
			}
		}
	}

	/**
	 * M-step for the transition matrix.
	 *
	 * @return summed absolute difference between the old entries and the new
	 *         (unfloored) estimates
	 */
	private double ReestimateTransMatrix() {
		double diffSum = 0.0;
		for (int i = 0; i < max_states; ++i) {
			// Fresh accumulators for every source state; sharing them across
			// rows would normalise later rows by the wrong total.
			double[] expected = new double[max_states];
			double rowTotal = 0.0;
			for (int j = 0; j < max_states; ++j) {
				for (int t = 0; t < num_observations - 1; ++t) {
					expected[j] += xi[i][j][t];
				}
				rowTotal += expected[j];
			}
			if (!(rowTotal > 0)) {
				// No expected transitions out of this state: keep the old row.
				continue;
			}
			for (int j = 0; j < max_states; ++j) {
				double prob = expected[j] / rowTotal;
				diffSum += Math.abs(prob - transMatrix[i][j]);
				transMatrix[i][j] = (prob > MIN_PROB) ? prob : MIN_PROB;
			}
		}
		return diffSum;
	}

	/**
	 * M-step for the emission matrix.
	 *
	 * @return summed absolute difference between the old entries and the new
	 *         (unfloored) estimates
	 */
	private double ReestimateEmissMatrix() {
		double diffSum = 0.0;
		for (int s = 0; s < max_states; ++s) {
			double[] expected = new double[max_symbols];
			double occupancy = 0.0;
			for (int t = 0; t < num_observations; ++t) {
				occupancy += gamma[s][t];
				expected[obserArray[t]] += gamma[s][t];
			}
			if (!(occupancy > 0)) {
				// State is never occupied: keep the old row.
				continue;
			}
			for (int k = 0; k < max_symbols; ++k) {
				double prob = expected[k] / occupancy;
				diffSum += Math.abs(prob - emissMatrix[s][k]);
				emissMatrix[s][k] = (prob > MIN_PROB) ? prob : MIN_PROB;
			}
		}
		return diffSum;
	}

	/**
	 * Baum-Welch training: resets the model to uniform distributions, then
	 * alternates E- and M-steps until the summed parameter change is at most
	 * the convergence threshold. Progress and the current matrices are printed
	 * after every iteration. The initial distribution is not re-estimated.
	 *
	 * @throws IllegalStateException
	 *             if model dimensions or observations have not been set
	 */
	public void ForwardBackward() {
		requireObservations();
		double diffSumTrans = 0.0;
		double diffSumEmiss = 0.0;
		// Initialize
		RandInitMatrix();
		do {
			// E-step
			computePosteriors();
			// M-step
			diffSumTrans = ReestimateTransMatrix();
			diffSumEmiss = ReestimateEmissMatrix();
			// for debug
			System.out.print("转移概率增量:");
			System.out.println(diffSumTrans);
			System.out.print("发射概率增量:");
			System.out.println(diffSumEmiss);
			System.out.println();

			showProbs();
		} while (diffSumTrans + diffSumEmiss > CONVERGENCE_THRESHOLD);
	}

	/** Prints the transition and emission matrices, one tab-separated row per state. */
	public void showProbs() {
		System.out.print("转移概率:\n");
		printRows(transMatrix, max_states);
		System.out.print("发射概率:\n");
		printRows(emissMatrix, max_symbols);
	}

	/** Prints the first {@code columns} entries of each of the first max_states rows. */
	private void printRows(double[][] matrix, int columns) {
		for (int i = 0; i < max_states; ++i) {
			for (int j = 0; j < columns; ++j) {
				System.out.print(matrix[i][j]);
				System.out.print("\t");
			}
			System.out.print("\n");
		}
	}

	/**
	 * Viterbi decoding: finds the most probable state sequence for the loaded
	 * observations, prints it from the last time step back to the first, and
	 * keeps it for {@link #getViterbiPath()}. The probability printed for a
	 * step is that of the best partial path ending in the decoded state there.
	 *
	 * @return probability of the most probable state sequence
	 * @throws IllegalStateException
	 *             if the model or the observations have not been set
	 */
	public double Viterbi() {
		requireObservations();
		requireModel();
		// Initialisation
		for (int s = 0; s < max_states; ++s) {
			viterbi[s][0] = pi[s] * emissMatrix[s][obserArray[0]];
			backTrace[s][0] = 0; // start marker, never followed
		}
		// Recursion
		for (int t = 1; t < num_observations; ++t) {
			for (int s = 0; s < max_states; ++s) {
				double best = 0.0;
				int bestPrev = 0;
				for (int ps = 0; ps < max_states; ++ps) {
					double prob = viterbi[ps][t - 1] * transMatrix[ps][s]
							* emissMatrix[s][obserArray[t]];
					if (best < prob) {
						best = prob;
						bestPrev = ps;
					}
				}
				viterbi[s][t] = best;
				backTrace[s][t] = bestPrev;
			}
		}
		// Termination: only the last column decides the winning path.
		int last = num_observations - 1;
		double retval = 0.0;
		int finalPos = 0;
		for (int s = 0; s < max_states; ++s) {
			if (retval < viterbi[s][last]) {
				retval = viterbi[s][last];
				finalPos = s;
			}
		}
		// Backtracking: follow the stored predecessors from the final state.
		int[] path = new int[num_observations];
		path[last] = finalPos;
		for (int t = last; t > 0; --t) {
			path[t - 1] = backTrace[path[t]][t];
		}
		bestPath = path;

		for (int t = last; t >= 0; --t) {
			System.out.println("序号:" + t);
			System.out.println("标记:" + path[t] + "\t" + "概率:" + viterbi[path[t]][t]);
		}
		return retval;
	}

	/**
	 * Returns the state sequence decoded by the most recent {@link #Viterbi()}
	 * call, one state per observation.
	 *
	 * @return a copy of the decoded path, or an empty array if nothing has
	 *         been decoded for the current observations
	 */
	public int[] getViterbiPath() {
		return (bestPath == null) ? new int[0] : bestPath.clone();
	}

	/** Allocates all per-observation buffers and discards any decoded path. */
	private void allocateWorkBuffers() {
		forward = new double[max_states][num_observations];
		backward = new double[max_states][num_observations];
		gamma = new double[max_states][num_observations];
		xi = new double[max_states][max_states][num_observations];
		viterbi = new double[max_states][num_observations];
		backTrace = new int[max_states][num_observations];
		bestPath = null;
	}

	/** Fails fast when dimensions or observations are missing. */
	private void requireObservations() {
		if (max_states <= 0 || max_symbols <= 0) {
			throw new IllegalStateException("model dimensions have not been set");
		}
		if (obserArray == null || num_observations <= 0 || forward == null
				|| forward.length != max_states) {
			throw new IllegalStateException("observations have not been loaded");
		}
	}

	/** Fails fast when the model parameters are missing. */
	private void requireModel() {
		if (pi == null || transMatrix == null || emissMatrix == null) {
			throw new IllegalStateException("model parameters have not been loaded");
		}
	}

	/**
	 * Reads the next line and splits it on whitespace.
	 *
	 * @throws IOException
	 *             if the file has ended or the line holds fewer than
	 *             {@code expected} values
	 */
	private static String[] readTokens(BufferedReader br, int expected, String what)
			throws IOException {
		String line = br.readLine();
		if (line == null) {
			throw new IOException("unexpected end of file while reading " + what);
		}
		String[] tokens = line.trim().split("\\s+");
		if (tokens.length < expected || tokens[0].length() == 0) {
			throw new IOException("expected " + expected + " value(s) for " + what
					+ " but found: \"" + line + "\"");
		}
		return tokens;
	}

	/** Parses the first {@code count} tokens as doubles. */
	private static double[] parseRow(String[] tokens, int count) {
		double[] row = new double[count];
		for (int i = 0; i < count; ++i) {
			row[i] = Double.parseDouble(tokens[i]);
		}
		return row;
	}

	/**
	 * Loads a model and an observation sequence, prints the model and decodes
	 * the observations.
	 *
	 * @param args
	 *            optional: model file path and observation file path; the
	 *            hard-coded defaults are used when fewer than two are given
	 * @throws IOException
	 *             if either file cannot be read
	 */
	public static void main(String[] args) throws IOException {
		String modelPath = (args.length >= 2) ? args[0] : DEFAULT_MODEL_PATH;
		String observationPath = (args.length >= 2) ? args[1] : DEFAULT_OBSERVATION_PATH;

		Hmm hmm = new Hmm();
		// hmm.setObserArray();
		// hmm.ForwardBackward();
		hmm.InitFile(modelPath);
		hmm.showProbs();
		hmm.ReadObservation(observationPath);
		hmm.Viterbi();
	}
}

你可能感兴趣的:(算法,String,Random,Class)