hadoop下实现kmeans算法——一个mapreduce的实现方法

写mapreduce程序实现kmeans算法,我们的思路可能是这样的

1. 用一个全局变量存放上一次迭代后的质心

2. map里,计算每个质心与样本之间的距离,得到与样本距离最短的质心,以这个质心作为key,样本作为value,输出

3. reduce里,输入的key是质心,value是其他的样本,这时重新计算聚类中心,将聚类中心put到一个全部变量t中。

4. 在main里比较前一次的质心和本次的质心是否发生变化,如果变化,则继续迭代,否则退出。

本文的思路基本上是按照上面的步骤来做的,只不过有几个问题需要解决

1. Hadoop是不存在自定义的全局变量的,所以上面定义一个全局变量存放质心的想法是实现不了的,所以一个替代的思路是将质心存放在文件中

2. 存放质心的文件在什么地方读取,如果在map中读取,那么可以肯定我们是不能用一个mapreduce实现一次迭代,所以我们选择在main函数里读取质心,然后将质心set到configuration中,configuration在map和reduce都是可读

3. 如何比较质心是否发生变化,是在main里比较么,读取本次质心和上一次质心的文件然后进行比较,这种方法是可以实现的,但是显得不够高富帅,这个时候我们用到了自定义的counter,counter是全局变量,在map和reduce中可读可写,在上面的思路中,我们看到reduce是有上次迭代的质心和刚刚计算出来的质心的,所以直接在reduce中进行比较就完全可以,如果没发生变化,counter加1。只要在main里比较获取counter的值就行了。

梳理一下,具体的步骤如下

1. main函数读取质心文件

2. 将质心的字符串放到configuration中

3. 在mapper类重写setup方法,获取到configuration的质心内容,解析成二维数组的形式,代表质心

4. mapper类中的map方法读取样本文件,跟所有的质心比较,得出每个样本跟哪个质心最近,然后输出<质心,样本>

5. reducer类中重新计算质心,如果重新计算出来的质心跟进来时的质心一致,那么自定义的counter加1

6. main中获取counter的值,看是否等于质心,如果不相等,那么继续迭代,否在退出

具体的实现如下

1. pom依赖

这个要跟集群的一致,因为如果不一致在计算其他问题的时候没有问题,但是在使用counter的时候会出现问题

Java.lang.IncompatibleClassChangeError: Found interface org.apache.Hadoop.mapreduce.Counter, but class was expected

原因是:其实从2.0开始,org.apache.hadoop.mapreduce.Counter从1.0版本的class改为interface,可以看一下你导入的这个类是class还是interface,如果是class那么就是导包导入的不对,需要修改

2. 样本

实例样本如下

[plain]  view plain  copy

 


  1. 1,1  
  2. 2,2  
  3. 3,3  
  4. -3,-3  
  5. -4,-4  
  6. -5,-5  

3. 质心

这个质心是从样本中随机找的

[plain]  view plain  copy

 


  1. 1,1  
  2. 2,2  

4. 代码实现

首先定义一个Center类,这个类主要存放了质心的个数k,还有两个从hdfs上读取质心文件的方法,一个用来读取初始的质心,这个实在文件中,还有一个是用来读取每次迭代后的质心文件夹,这个是在文件夹中的,代码如下

Center类

[java]  view plain  copy

 


  1. public class Center {  
  2.   
  3.     protected static int k = 2;     //质心的个数  
  4.       
  5.     /** 
  6.      * 从初始的质心文件中加载质心,并返回字符串,质心之间用tab分割 
  7.      * @param path 
  8.      * @return 
  9.      * @throws IOException 
  10.      */  
  11.     public String loadInitCenter(Path path) throws IOException {  
  12.           
  13.         StringBuffer sb = new StringBuffer();  
  14.           
  15.         Configuration conf = new Configuration();  
  16.         FileSystem hdfs = FileSystem.get(conf);  
  17.         FSDataInputStream dis = hdfs.open(path);  
  18.         LineReader in = new LineReader(dis, conf);  
  19.         Text line = new Text();  
  20.         while(in.readLine(line) > 0) {  
  21.             sb.append(line.toString().trim());  
  22.             sb.append(”\t”);  
  23.         }  
  24.           
  25.         return sb.toString().trim();  
  26.     }  
  27.       
  28.     /** 
  29.      * 从每次迭代的质心文件中读取质心,并返回字符串 
  30.      * @param path 
  31.      * @return 
  32.      * @throws IOException 
  33.      */  
  34.     public String loadCenter(Path path) throws IOException {  
  35.           
  36.         StringBuffer sb = new StringBuffer();  
  37.           
  38.         Configuration conf = new Configuration();  
  39.         FileSystem hdfs = FileSystem.get(conf);  
  40.         FileStatus[] files = hdfs.listStatus(path);  
  41.           
  42.         for(int i = 0; i < files.length; i++) {  
  43.               
  44.             Path filePath = files[i].getPath();  
  45.             if(!filePath.getName().contains(“part”)) continue;  
  46.             FSDataInputStream dis = hdfs.open(filePath);  
  47.             LineReader in = new LineReader(dis, conf);  
  48.             Text line = new Text();  
  49.             while(in.readLine(line) > 0) {  
  50.                 sb.append(line.toString().trim());  
  51.                 sb.append(”\t”);  
  52.             }  
  53.         }  
  54.           
  55.         return sb.toString().trim();  
  56.     }  
  57. }  

KmeansMR类

[java]  view plain  copy

 


  1. public class KmeansMR {  
  2.   
  3.     private static String FLAG = “KCLUSTER”;  
  4.           
  5.     public static class TokenizerMapper   
  6.     extends Mapper{  
  7.           
  8.         double[][] centers = new double[Center.k][];  
  9.         String[] centerstrArray = null;  
  10.           
  11.         @Override  
  12.         public void setup(Context context) {  
  13.               
  14.             //将放在context中的聚类中心转换为数组的形式,方便使用  
  15.             String kmeansS = context.getConfiguration().get(FLAG);  
  16.             centerstrArray = kmeansS.split(”\t”);  
  17.             for(int i = 0; i < centerstrArray.length; i++) {  
  18.                 String[] segs = centerstrArray[i].split(”,”);  
  19.                 centers[i] = new double[segs.length];  
  20.                 for(int j = 0; j < segs.length; j++) {  
  21.                     centers[i][j] = Double.parseDouble(segs[j]);  
  22.                 }  
  23.             }  
  24.         }  
  25.           
  26.         public void map(Object key, Text value, Context context  
  27.                  ) throws IOException, InterruptedException {  
  28.               
  29.             String line = value.toString();  
  30.             String[] segs = line.split(”,”);  
  31.             double[] sample = new double[segs.length];  
  32.             for(int i = 0; i < segs.length; i++) {  
  33.                 sample[i] = Float.parseFloat(segs[i]);  
  34.             }  
  35.             //求得距离最近的质心  
  36.             double min = Double.MAX_VALUE;  
  37.             int index = 0;  
  38.             for(int i = 0; i < centers.length; i++) {  
  39.                 double dis = distance(centers[i], sample);  
  40.                 if(dis < min) {  
  41.                     min = dis;  
  42.                     index = i;  
  43.                 }  
  44.             }  
  45.               
  46.             context.write(new Text(centerstrArray[index]), new Text(line));  
  47.         }  
  48.     }  
  49.   
  50.     public static class IntSumReducer   
  51.     extends Reducer {  
  52.   
  53.         Counter counter = null;  
  54.           
  55.         public void reduce(Text key, Iterable values,   
  56.                     Context context  
  57.                     ) throws IOException, InterruptedException {  
  58.               
  59.             double[] sum = new double[Center.k];  
  60.             int size = 0;  
  61.             //计算对应维度上值的加和,存放在sum数组中  
  62.             for(Text text : values) {  
  63.                 String[] segs = text.toString().split(”,”);  
  64.                 for(int i = 0; i < segs.length; i++) {  
  65.                     sum[i] += Double.parseDouble(segs[i]);  
  66.                 }  
  67.                 size ++;  
  68.             }  
  69.               
  70.             //求sum数组中每个维度的平均值,也就是新的质心  
  71.             StringBuffer sb = new StringBuffer();  
  72.             for(int i = 0; i < sum.length; i++) {  
  73.                 sum[i] /= size;  
  74.                 sb.append(sum[i]);  
  75.                 sb.append(”,”);  
  76.             }  
  77.               
  78.             /**判断新的质心跟老的质心是否是一样的*/  
  79.             boolean flag = true;  
  80.             String[] centerStrArray = key.toString().split(”,”);  
  81.             for(int i = 0; i < centerStrArray.length; i++) {  
  82.                 if(Math.abs(Double.parseDouble(centerStrArray[i]) - sum[i]) > 0.00000000001) {  
  83.                     flag = false;  
  84.                     break;  
  85.                 }  
  86.             }  
  87.             //如果新的质心跟老的质心是一样的,那么相应的计数器加1  
  88.             if(flag) {  
  89.                 counter = context.getCounter(”myCounter”“kmenasCounter”);  
  90.                 counter.increment(1l);  
  91.             }  
  92.             context.write(nullnew Text(sb.toString()));  
  93.         }  
  94.     }  
  95.   
  96.     public static void main(String[] args) throws Exception {  
  97.   
  98.         Path kMeansPath = new Path(“/dsap/middata/kmeans/kMeans”);  //初始的质心文件  
  99.         Path samplePath = new Path(“/dsap/middata/kmeans/sample”);  //样本文件  
  100.         //加载聚类中心文件  
  101.         Center center = new Center();  
  102.         String centerString = center.loadInitCenter(kMeansPath);  
  103.           
  104.         int index = 0;  //迭代的次数  
  105.         while(index < 5) {  
  106.               
  107.             Configuration conf = new Configuration();  
  108.             conf.set(FLAG, centerString);   //将聚类中心的字符串放到configuration中  
  109.               
  110.             kMeansPath = new Path(“/dsap/middata/kmeans/kMeans” + index);   //本次迭代的输出路径,也是下一次质心的读取路径  
  111.               
  112.             /**判断输出路径是否存在,如果存在,则删除*/  
  113.             FileSystem hdfs = FileSystem.get(conf);  
  114.             if(hdfs.exists(kMeansPath)) hdfs.delete(kMeansPath);  
  115.   
  116.             Job job = new Job(conf, “kmeans” + index);   
  117.             job.setJarByClass(KmeansMR.class);  
  118.             job.setMapperClass(TokenizerMapper.class);  
  119.             job.setReducerClass(IntSumReducer.class);  
  120.             job.setOutputKeyClass(NullWritable.class);  
  121.             job.setOutputValueClass(Text.class);  
  122.             job.setMapOutputKeyClass(Text.class);  
  123.             job.setMapOutputValueClass(Text.class);  
  124.             FileInputFormat.addInputPath(job, samplePath);  
  125.             FileOutputFormat.setOutputPath(job, kMeansPath);  
  126.             job.waitForCompletion(true);  
  127.               
  128.             /**获取自定义counter的大小,如果等于质心的大小,说明质心已经不会发生变化了,则程序停止迭代*/  
  129.             long counter = job.getCounters().getGroup(“myCounter”).findCounter(“kmenasCounter”).getValue();  
  130.             if(counter == Center.k) System.exit(0);  
  131.             /**重新加载质心*/  
  132.             center = new Center();  
  133.             centerString = center.loadCenter(kMeansPath);  
  134.               
  135.             index ++;  
  136.         }  
  137.         System.exit(0);  
  138.     }  
  139.       
  140.     public static double distance(double[] a, double[] b) {  
  141.           
  142.         if(a == null || b == null || a.length != b.length) return Double.MAX_VALUE;  
  143.         double dis = 0;  
  144.         for(int i = 0; i < a.length; i++) {  
  145.             dis += Math.pow(a[i] - b[i], 2);  
  146.         }  
  147.         return Math.sqrt(dis);  
  148.     }  
  149. }     

5. 结果

产生了两个文件夹,分别是第一次、第二次迭代后的聚类中心


最后的聚类中心的内容如下


from: http://blog.csdn.net/nwpuwyk/article/details/29564249?utm_source=tuicool&utm_medium=referral

你可能感兴趣的:(机器学习)