Weka中EM算法详解

 1  private void EM_Init (Instances inst)  2     throws Exception {  3     int i, j, k;  4 

 5     // 由于EM算法对初始值较敏感,故选择run k means 10 times and choose best solution

 6     SimpleKMeans bestK = null;  7     double bestSqE = Double.MAX_VALUE;  8     for (i = 0; i < 10; i++) {  9       SimpleKMeans sk = new SimpleKMeans(); 10  sk.setSeed(m_rr.nextInt()); 11  sk.setNumClusters(m_num_clusters); 12       sk.setDisplayStdDevs(true); 13  sk.buildClusterer(inst); 14       //KMeans中各个cluster的平方误差

15       if (sk.getSquaredError() < bestSqE) { 16          

17           bestSqE = sk.getSquaredError(); 18           bestK = sk; 19  } 20  } 21     

22     /*************** KMeans Finds the best cluster number *****************/

23     

24     

25     // initialize with best k-means solution

26     m_num_clusters = bestK.numberOfClusters(); 27     // 每个样本所在各个集群的概率

28     m_weights = new double[inst.numInstances()][m_num_clusters]; 29     // 评估每个集群所对应的离散型属性的相关取值
30
m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs]; 31 // 每个集群所对应的连续性属性数所对应的相关取值(均值,标准偏差,样本权值(进行归一化)) 32 m_modelNormal = new double[m_num_clusters][m_num_attribs][3]; 33 // 每个集群所对应的先验概率 34 m_priors = new double[m_num_clusters]; 35 // 每个集群所对应的中心点 36 Instances centers = bestK.getClusterCentroids(); 37 // 每个集群所对应的标准差 38 Instances stdD = bestK.getClusterStandardDevs(); 39 // ??? Returns for each cluster the frequency counts for the values of each nominal attribute 40 int [][][] nominalCounts = bestK.getClusterNominalCounts(); 41 // 得到每个集群所对应的样本数 42 int [] clusterSizes = bestK.getClusterSizes(); 43 44 for (i = 0; i < m_num_clusters; i++) { 45 Instance center = centers.instance(i); 46 for (j = 0; j < m_num_attribs; j++) { 47 48 // 样本属性是离散型 49 if (inst.attribute(j).isNominal()) 50 { 51 m_model[i][j] = new DiscreteEstimator(m_theInstances.attribute(j).numValues() 52 , true); 53 for (k = 0; k < inst.attribute(j).numValues(); k++) { 54 m_model[i][j].addValue(k, nominalCounts[i][j][k]); 55 } 56 } 57 //// 样本属性是连续型 58 else 59 { 60 double minStdD = (m_minStdDevPerAtt != null)? m_minStdDevPerAtt[j]: m_minStdDev; 61 double mean = (center.isMissing(j))? inst.meanOrMode(j): center.value(j); 62 m_modelNormal[i][j][0] = mean; 63 double stdv = (stdD.instance(i).isMissing(j))? ((m_maxValues[j] - 64 m_minValues[j]) / (2 * m_num_clusters)): stdD.instance(i).value(j); 65 if (stdv < minStdD) 66 { 67 stdv = inst.attributeStats(j).numericStats.stdDev; 68 if (Double.isInfinite(stdv)) { 69 stdv = minStdD; 70 } 71 if (stdv < minStdD) { 72 stdv = minStdD; 73 } 74 } 75 if (stdv <= 0) { 76 stdv = m_minStdDev; 77 } 78 79 m_modelNormal[i][j][1] = stdv; 80 m_modelNormal[i][j][2] = 1.0; 81 } 82 } 83 } 84 85 86 for (j = 0; j < m_num_clusters; j++) { 87 // 计算每个集群的先验概率 88 m_priors[j] = clusterSizes[j]; 89 } 90 Utils.normalize(m_priors); 91 }

 

你可能感兴趣的:(算法)