IPFP(Iterative Proportional Fitting Procedure,迭代比例拟合)算法实现

以下是我原创的 IPFP 算法实现,欢迎大家交流,也恳请各位高手批评指正。

package IPFProcedure;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

public class IPFProcedureFuntion {

    /*
     * Iterative proportional fitting procedure (IPFP):
     *  (1) Start from an initial joint distribution Q_0(X) and a set of
     *      marginal constraints R = {R(Y^1), ..., R(Y^m)}.
     *  (2) For k = 1, 2, ... repeat until convergence:
     *      2.1  i = k mod m
     *      2.2  if Q_{k-1}(Y^i) != 0:  Q_k(X) = Q_{k-1}(X) * R(Y^i) / Q_{k-1}(Y^i)
     *           else:                   Q_k(X) = 0
     *      2.3  k = k + 1
     *
     * Key encoding used throughout:
     *  - A joint-table key has one character per variable in X, 'T' or 'F'
     *    (e.g. "TFT" means X=T, Y=F, Z=T).
     *  - A constraint-table key uses the same layout, with 'o' for each variable
     *    the constraint marginalises out (e.g. "TTo" constrains X and Y only).
     */

    /** Model parameter epsilon; main() builds its demo constraints from it. The fitting code does not read it. */
    public double c = 0;
    /** Constraint set: target marginal tables R(Y^i). */
    public ArrayList<HashMap<String, Double>> R;
    /** For constraint i, mask[k] is true when variable k is constrained and false when it is marginalised out ('o'). */
    public ArrayList<Boolean[]> ConstraitVariableList;
    /** Number of constraints; equals R.size(). */
    public int m = 0;
    /** Variable names, used only when printing. */
    public String[] X;
    /** Number of variables; equals X.length. */
    public int variableNum;
    /** Current iterate Q_k (joint distribution). */
    public LinkedHashMap<String, Double> Q;
    /** Previous iterate Q_{k-1}; compared with Q to detect convergence. */
    public LinkedHashMap<String, Double> lastQ;
    /** Index k of the next IPFP step; starts at 1. */
    public int iterativeCount = 1;
    /** Per-cell "unchanged in the last step" flags, maintained by iterativeFunctionFittingProcedure(). */
    public HashMap<String, Boolean> convergenceMap;
    /** Whether the procedure has converged. */
    public boolean isConvergenced = false;
    /** Absolute tolerance under which two probabilities count as equal; exact double equality may never be reached. */
    public double convergenceTolerance = 1e-10;
    /** Iteration stops without convergence once iterativeCount reaches this value. */
    public int maxIterations = 100;

    /**
     * Initialises all fields from the variables, the constraint set and the initial joint table.
     *
     * @param X variable names; key position k refers to X[k]
     * @param R constraint tables; each must be non-empty, with keys of length X.length over {T, F, o}
     * @param Q initial joint distribution Q_0, keyed by strings of length X.length over {T, F}
     * @throws IllegalArgumentException if R is empty or a constraint table is malformed
     */
    public void init(String[] X, ArrayList<HashMap<String, Double>> R, HashMap<String, Double> Q) {
        this.X = X.clone();
        this.variableNum = X.length;

        // With m == 0 the step index k % m would divide by zero.
        if (R.isEmpty()) {
            throw new IllegalArgumentException("constrait set data error! constraint set R is empty");
        }
        this.R = new ArrayList<HashMap<String, Double>>(R);
        this.m = R.size();
        ConstraitVariableList = new ArrayList<Boolean[]>();
        for (HashMap<String, Double> constraint : R) {
            ConstraitVariableList.add(parseConstraintMask(constraint));
        }

        this.Q = new LinkedHashMap<String, Double>(Q);
        convergenceMap = new HashMap<String, Boolean>();
        // lastQ starts at zero so the first convergence check never passes for a non-zero Q_0.
        lastQ = new LinkedHashMap<String, Double>();
        for (String key : this.Q.keySet()) {
            convergenceMap.put(key, false);
            lastQ.put(key, 0d);
        }
    }

    /** Derives the constrained-variable mask of one constraint table from one of its keys. */
    private Boolean[] parseConstraintMask(HashMap<String, Double> constraint) {
        if (constraint.isEmpty()) {
            throw new IllegalArgumentException("constrait set data error! empty constraint table");
        }
        String sample = constraint.keySet().iterator().next();
        if (sample.length() != variableNum) {
            throw new IllegalArgumentException("constrait set data error! key " + sample
                    + " should have " + variableNum + " characters");
        }
        Boolean[] mask = new Boolean[variableNum];
        for (int j = 0; j < variableNum; j++) {
            char ch = sample.charAt(j);
            if (ch == 'T' || ch == 'F') {
                mask[j] = Boolean.TRUE;
            } else if (ch == 'o') {
                mask[j] = Boolean.FALSE;
            } else {
                throw new IllegalArgumentException("constrait set data error! unexpected character '"
                        + ch + "' in key " + sample);
            }
        }
        return mask;
    }

    /**
     * Runs IPFP until convergence or until {@link #maxIterations} is reached, printing the joint table after every step.
     *
     * <p>A step that leaves Q unchanged only shows that Q already satisfies the single
     * constraint that step projected onto. Q is therefore declared converged after
     * {@code m} consecutive unchanged steps, i.e. once every constraint holds within
     * {@link #convergenceTolerance}.
     */
    public void iterativeFunctionFittingProcedureConvergenceModify() {
        int unchangedSteps = 0;
        while (true) {
            unchangedSteps = judgeIsConvergenced() ? unchangedSteps + 1 : 0;
            this.isConvergenced = unchangedSteps >= m;
            if (this.isConvergenced) {
                System.out.println("联合概率分布已经收敛!结果如下:");
                printResult();
                return;
            }
            if (iterativeCount >= maxIterations) {
                System.out.println("迭代次数超过" + maxIterations + "!");
                return;
            }
            performStep();
            System.out.println("the IPFP Step " + iterativeCount + " has completed");
            System.out.println("收敛情况: " + isConvergenced);
            printResult();
            System.out.println();
            iterativeCount++;
        }
    }

    /**
     * Variant of the IPFP driver that also records per-cell stability in {@link #convergenceMap} and prints it after every step.
     *
     * <p>Every cell is rescaled on every step. A cell that is momentarily unchanged can
     * still move once a later constraint is applied, so no cell is ever frozen.
     * Convergence requires {@code m} consecutive steps in which all cells were stable.
     */
    public void iterativeFunctionFittingProcedure() {
        int stableSteps = 0;
        while (true) {
            if (isConvergenced) {
                System.out.println("联合概率分布已经收敛!结果如下:");
                printResult();
                return;
            }
            if (iterativeCount >= maxIterations) {
                System.out.println("迭代次数超过" + maxIterations + "!");
                return;
            }
            performStep();
            for (Map.Entry<String, Double> cell : Q.entrySet()) {
                convergenceMap.put(cell.getKey(), isClose(cell.getValue(), lastQ.get(cell.getKey())));
            }
            System.out.println("the IPFP Step " + iterativeCount + " has completed");

            printResult();
            stableSteps = judgeConvergence() ? stableSteps + 1 : 0;
            isConvergenced = stableSteps >= m;
            System.out.println("收敛情况: " + isConvergenced);
            printQConvergence();
            System.out.println();
            iterativeCount++;
        }
    }

    /**
     * Executes IPFP step k = iterativeCount: snapshots Q into lastQ, then rescales every cell toward constraint i = k mod m.
     *
     * @throws IllegalStateException if the constraint table lacks a value for a projection with non-zero mass
     */
    private void performStep() {
        // Take the full snapshot first so every marginal Q_{k-1}(Y^i) is computed from the same iterate.
        lastQ.putAll(Q);
        int i = iterativeCount % m;
        Boolean[] mask = ConstraitVariableList.get(i);
        HashMap<String, Double> target = R.get(i);
        // Many cells share one projection; compute each marginal once per step.
        HashMap<String, Double> marginalCache = new HashMap<String, Double>();
        for (Map.Entry<String, Double> cell : lastQ.entrySet()) {
            String key = cell.getKey();
            String projected = projectKey(key, mask);
            Double marginal = marginalCache.get(projected);
            if (marginal == null) {
                marginal = accMarginalProbabilityFunction(projected);
                marginalCache.put(projected, marginal);
            }
            double updated;
            if (marginal == 0d) {
                updated = 0d;
            } else {
                Double r = target.get(projected);
                if (r == null) {
                    throw new IllegalStateException("constraint " + i + " has no target value for " + projected);
                }
                updated = cell.getValue() * r / marginal;
            }
            Q.put(key, updated);
        }
    }

    /** Replaces every unconstrained position of a joint key with 'o', giving the matching constraint-table key. */
    private String projectKey(String key, Boolean[] mask) {
        char[] projected = new char[variableNum];
        for (int k = 0; k < variableNum; k++) {
            projected[k] = mask[k] ? key.charAt(k) : 'o';
        }
        return new String(projected);
    }

    /**
     * Computes the marginal probability Q_{k-1}(Y^i) of a query pattern from lastQ.
     *
     * @param query key over {T, F, o}; 'o' positions match any value
     * @return sum of lastQ over all joint keys matching the query
     */
    public Double accMarginalProbabilityFunction(String query) {
        double total = 0d;
        for (Map.Entry<String, Double> cell : lastQ.entrySet()) {
            if (judgeQueryStringIsMatched(cell.getKey(), query)) {
                total += cell.getValue();
            }
        }
        return total;
    }

    /**
     * Tells whether a joint-table key matches a query pattern.
     *
     * <p>A 'T'/'F' in the query must equal the key at that position; any other query character matches anything.
     *
     * @return false when the key contradicts the query or the lengths differ
     */
    public boolean judgeQueryStringIsMatched(String key, String query) {
        if (key.length() != variableNum || query.length() != variableNum)
            System.out.println("概率查询出错,请检查数据输入初始化是否正确");
        // Without this guard a shorter key would throw StringIndexOutOfBoundsException below.
        if (key.length() != query.length()) {
            return false;
        }
        for (int i = 0; i < query.length(); i++) {
            char q = query.charAt(i);
            char k = key.charAt(i);
            if ((q == 'T' && k == 'F') || (q == 'F' && k == 'T')) {
                return false;
            }
        }
        return true;
    }

    /** Returns true when Q and lastQ agree cell by cell within {@link #convergenceTolerance}. */
    public boolean judgeIsConvergenced() {
        for (Map.Entry<String, Double> cell : Q.entrySet()) {
            if (!isClose(cell.getValue(), lastQ.get(cell.getKey()))) {
                return false;
            }
        }
        return true;
    }

    /** Returns true when every per-cell flag in {@link #convergenceMap} is set. */
    public boolean judgeConvergence() {
        for (Boolean flag : convergenceMap.values()) {
            if (!flag) {
                return false;
            }
        }
        return true;
    }

    /** Absolute-tolerance comparison; a NaN difference is never "close". */
    private boolean isClose(double a, double b) {
        return Math.abs(a - b) <= convergenceTolerance;
    }

    /** Prints the current joint table Q, with T-heavy assignments first. Could also be written to CSV. */
    public void printResult() {
        if (isConvergenced) {
            int num = iterativeCount - 1;
            System.out.println(num + " Step, " + "最后收敛的联合概率分布如下: ");
        } else {
            System.out.println("第 " + iterativeCount + " 次迭代概率分布如下: ");
        }
        String[] keys = Q.keySet().toArray(new String[Q.size()]);
        // Descending character order; 'T' > 'F', so T assignments are listed first.
        Arrays.sort(keys, new Comparator<String>() {
            @Override
            public int compare(String s1, String s2) {
                int n = Math.min(s1.length(), s2.length());
                for (int i = 0; i < n; i++) {
                    if (s1.charAt(i) > s2.charAt(i))
                        return -1;
                    if (s1.charAt(i) < s2.charAt(i))
                        return 1;
                }
                return Integer.compare(s2.length(), s1.length());
            }
        });
        for (String key : keys)
            printVariable(key);

        System.out.println("联合概率Q打印完毕!");
    }

    /** Prints one joint assignment together with its probability in Q. */
    public void printVariable(String s) {
        System.out.println(describeAssignment(s) + "probality is " + Q.get(s));
    }

    /** Prints every cell's per-cell convergence flag. */
    public void printQConvergence() {
        for (Map.Entry<String, Boolean> cell : convergenceMap.entrySet()) {
            System.out.println(describeAssignment(cell.getKey()) + "convergenced is " + cell.getValue());
        }
    }

    /** Renders a key as "X = T,Y = F,..." with a trailing comma; 'o' positions contribute only the comma. */
    private String describeAssignment(String key) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < key.length(); i++) {
            char ch = key.charAt(i);
            if (ch == 'T' || ch == 'F')
                sb.append(X[i]).append(" = ").append(ch);
            sb.append(",");
        }
        return sb.toString();
    }

    /** Demo: fits a 3-variable joint distribution to three pairwise marginal constraints built from C = 0.2. */
    public static void main(String[] args) {
        IPFProcedureFuntion ipfp = new IPFProcedureFuntion();
        ipfp.setC(4 / 20d);
        double C = ipfp.getC();
        String[] X = {"X", "Y", "Z"};

        // Uniform initial distribution over all 8 assignments.
        HashMap<String, Double> Q = new HashMap<String, Double>();
        for (String cell : new String[] {"TTT", "TTF", "TFT", "TFF", "FTT", "FTF", "FFT", "FFF"})
            Q.put(cell, 0.125d);

        ArrayList<HashMap<String, Double>> R = new ArrayList<HashMap<String, Double>>();

        HashMap<String, Double> R2 = new HashMap<String, Double>();
        R2.put("TTo", 1 / 2d - C);
        R2.put("TFo", C);
        R2.put("FTo", C);
        R2.put("FFo", 1 / 2d - C);
        R.add(R2);

        HashMap<String, Double> R3 = new HashMap<String, Double>();
        R3.put("oTT", 1 / 2d - C);
        R3.put("oTF", C);
        R3.put("oFT", C);
        R3.put("oFF", 1 / 2d - C);
        R.add(R3);

        HashMap<String, Double> R1 = new HashMap<String, Double>();
        R1.put("ToT", C);
        R1.put("ToF", 1 / 2d - C);
        R1.put("FoT", 1 / 2d - C);
        R1.put("FoF", C);
        R.add(R1);

        ipfp.init(X, R, Q);
        ipfp.iterativeFunctionFittingProcedureConvergenceModify();
        // On convergence the result has already been printed; print it only after a non-converged stop.
        if (!ipfp.isConvergenced)
            ipfp.printResult();
    }

    public double getC() {
        return c;
    }

    public void setC(double c) {
        this.c = c;
    }
}


你可能感兴趣的:(java代码,算法,string,c,query,equals,csv)