mahout 0.7关于随机选择数据的一个bug

[本文链接:http://www.cnblogs.com/breezedeus/archive/2012/09/05/2671572.html,转载请注明出处]

 

在K-means聚类算法里,我们首先需要在已有的数据点中选取K个点作为初始中心点。这个bug就出现在中心点的随机选取上,mahout的实现不是真的随机。

【位置】:

     org.apache.mahout.clustering.kmeans.RandomSeedGenerator#buildRandom(...) , 行 88 - 110 这段。

 

我简化了一下,mahout的随机抽取逻辑如下:

   1: /**
   2:  * Sample K integers from integer interval [0, N).
   3:  * @param N
   4:  * @param K
   5:  * @return
   6:  */
   7: private List<Integer> generateMahoutRandomSeed( int N, int K) {
   8:     List<Integer> chosen = Lists. newArrayListWithCapacity(K);
   9:     Random random = RandomUtils. getRandom();
  10:     for ( int n = 0; n < N; ++n) {
  11:         int currentSize = chosen.size();
  12:         if (currentSize < K) {
  13:             chosen.add(n);
  14:         } else if (random.nextInt(currentSize + 1) != 0) {
  15:             int indexToRemove = random.nextInt(currentSize); // evict one chosen randomly
  16:             chosen.remove(indexToRemove);
  17:             chosen.add(n);
  18:         }
  19:     }
  20:     return chosen;
  21: }
 
mahout上面的抽取逻辑是没法做到真正随机的,越在后面的数最终被抽取的概率会越大。下面是我的修正算法:
   1: private List<Integer> generateBDRandomSeed(int N, int K) {
   2:     List<Integer> chosen = Lists.newArrayListWithCapacity(K);
   3:     Random random = RandomUtils.getRandom();
   4:     for (int n = 0; n < N; ++n) {
   5:         int currentSize = chosen.size();
   6:         if (currentSize < K) {
   7:             chosen.add(n);
   8:         } else if (random.nextInt(n + 1) < K) {
   9:             int indexToRemove = random.nextInt(currentSize); // actually currentSize is always equal to K here
  10:             chosen.remove(indexToRemove);
  11:             chosen.add(n);
  12:         }
  13:     }
  14:     return chosen;
  15: }

 

为了说明问题,我对上面的代码写了个测试函数:

   1: @Test
   2: public void testMahoutRandomSeedGenerator() {
   3:     int N = 11;
   4:     int K = 3;
   5:     int numLoops = 100000;
   6:     int[] times = new int [N];
   7:     Arrays. fill(times, 0);
   8:     for ( int loop = 0; loop < numLoops; ++loop) {
   9:         //List<Integer> chosen = generateMahoutRandomSeed(N, K);
  10:         List<Integer> chosen = generateWjlRandomSeed(N, K);
  11:         for (Integer i : chosen) {
  12:             ++times[i];
  13:         }
  14:     }
  15:     for ( int n = 0; n < N; ++n) {
  16:         System. out .println(times[n] / ( double)(numLoops*K) );
  17:     }
  18: }

 

使用generateMahoutRandomSeed产生的结果如下:

0.03344333333333333

0.033523333333333336

0.033356666666666666

0.033036666666666666

0.044596666666666666

0.059013333333333334

0.07856

0.10567666666666667

0.14029

0.18848333333333334

0.25002

而使用修正后的generateBDRandomSeed,其结果如下:

0.09096

0.09011

0.09051666666666666

0.09143333333333334

0.09082

0.09034333333333333

0.09032

0.09103

0.09162

0.09148666666666666

0.09136

 

数学上可以证明我的算法是对的,证明可见我之前讲面试题目的一篇老博文。感兴趣的童鞋也可以想想为什么generateMahoutRandomSeed的实现不对:)

 

mahout的这个问题也可能出现在它其他随机抽取相关的代码中,所以建议用到mahout随机抽取代码的同学check一下再使用。

你可能感兴趣的:(Mahout)