K-means Clustering with MapReduce

I recently came across a K-means implementation on MapReduce online. The example itself is decent: http://blog.csdn.net/jshayzf/article/details/22739063

However, it has very few comments and too many parameters, which makes it hard for beginners to follow. So I wrote a simpler example based on my own understanding and added detailed comments.

The overall flow is:

1. For each record it reads, the map function computes the distance to every center, finds the closest one, and emits the record with that center's ID as the key and the record itself as the value.

2. The reduce phase groups values by key, which gathers together all records assigned to the same center; the reducer then averages those records component-wise and emits the average as the new center. For example, if center 1 received the records (1.0, 2.0) and (3.0, 4.0), the new center is (2.0, 3.0).

3. Compare the averages produced by reduce against the old centers. If they differ, overwrite the old centers file with the reduce output (the centers are kept in a file on HDFS; see the sample format below), delete the reduce output directory so the next iteration can write to it, and run the job again.

4. If the averages equal the old centers, delete the reduce output directory and run one final map-only job (no reduce) that emits each center ID paired with its records.
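Both the input data and the centers file are plain text with one comma-separated vector per line (this is the format Utils.textToArray parses). A hypothetical centers.txt for two centers with three features each, values made up purely for illustration, might look like:

1.2,3.4,5.6
2.1,0.8,4.4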

MapReduce.java:

package MyKmeans;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class MapReduce {

    public static class Map extends Mapper<LongWritable, Text, IntWritable, Text> {

        // the current cluster centers
        ArrayList<ArrayList<Double>> centers = null;
        // the number of centers, k
        int k = 0;

        // load the centers from HDFS once, before any map() calls
        @Override
        protected void setup(Context context) throws IOException,
                InterruptedException {
            centers = Utils.getCentersFromHDFS(context.getConfiguration().get("centersPath"), false);
            k = centers.size();
        }

        /**
         * 1. Compare each input record against every center and assign it to the closest one.
         * 2. Emit the center's ID as the key and the record as the value
         *    (e.g. "1  0.2": 1 is the center's ID, 0.2 is a record close to that center).
         */
        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            // parse one comma-separated record
            ArrayList<Double> fields = Utils.textToArray(value);
            int sizeOfFields = fields.size();

            double minDistance = Double.MAX_VALUE;
            int centerIndex = 0;

            // measure the record against each of the k centers in turn
            for (int i = 0; i < k; i++) {
                double currentDistance = 0;
                for (int j = 0; j < sizeOfFields; j++) {
                    double centerPoint = Math.abs(centers.get(i).get(j));
                    double field = Math.abs(fields.get(j));
                    // a normalized squared difference, not plain Euclidean distance
                    currentDistance += Math.pow((centerPoint - field) / (centerPoint + field), 2);
                }
                // remember the closest center seen so far
                if (currentDistance < minDistance) {
                    minDistance = currentDistance;
                    centerIndex = i;
                }
            }
            // emit the record unchanged, keyed by its closest center's ID
            context.write(new IntWritable(centerIndex + 1), value);
        }

    }

    // reduce's grouping brings together all records that share a center ID
    public static class Reduce extends Reducer<IntWritable, Text, Text, Text> {

        /**
         * 1. The key is a center's ID; the values are all records assigned to that center.
         * 2. Average the records component-wise to obtain the new center.
         */
        @Override
        protected void reduce(IntWritable key, Iterable<Text> value, Context context)
                throws IOException, InterruptedException {
            ArrayList<ArrayList<Double>> fieldsList = new ArrayList<ArrayList<Double>>();

            // collect the records; each line becomes one ArrayList<Double>
            for (Iterator<Text> it = value.iterator(); it.hasNext();) {
                ArrayList<Double> tempList = Utils.textToArray(it.next());
                fieldsList.add(tempList);
            }

            // compute the new center
            // number of fields per record
            int fieldSize = fieldsList.get(0).size();
            double[] avg = new double[fieldSize];
            for (int i = 0; i < fieldSize; i++) {
                // average of column i
                double sum = 0;
                int size = fieldsList.size();
                for (int j = 0; j < size; j++) {
                    sum += fieldsList.get(j).get(i);
                }
                avg[i] = sum / size;
            }
            // write the new center as one line; Arrays.toString yields "a, b, c",
            // and the stray spaces (plus the tab TextOutputFormat adds after the
            // empty key) are harmless because Double.parseDouble ignores
            // surrounding whitespace when the line is read back
            context.write(new Text(""), new Text(Arrays.toString(avg).replace("[", "").replace("]", "")));
        }

    }

    @SuppressWarnings("deprecation")
    public static void run(String centerPath, String dataPath, String newCenterPath, boolean runReduce) throws IOException, ClassNotFoundException, InterruptedException {

        Configuration conf = new Configuration();
        conf.set("centersPath", centerPath);

        Job job = new Job(conf, "mykmeans");
        job.setJarByClass(MapReduce.class);

        job.setMapperClass(Map.class);

        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);

        if (runReduce) {
            job.setReducerClass(Reduce.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(Text.class);
        } else {
            // the final labeling pass needs no reduce phase at all
            job.setNumReduceTasks(0);
        }

        FileInputFormat.addInputPath(job, new Path(dataPath));

        FileOutputFormat.setOutputPath(job, new Path(newCenterPath));

        System.out.println(job.waitForCompletion(true));
    }

    public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
        String centerPath = "hdfs://localhost:9000/input/centers.txt";
        String dataPath = "hdfs://localhost:9000/input/wine.txt";
        String newCenterPath = "hdfs://localhost:9000/out/kmean";

        int count = 0;

        // iterate until the centers stop moving, then run one final map-only pass
        while (true) {
            run(centerPath, dataPath, newCenterPath, true);
            System.out.println(" iteration " + ++count + " finished ");
            if (Utils.compareCenters(centerPath, newCenterPath)) {
                run(centerPath, dataPath, newCenterPath, false);
                break;
            }
        }
    }

}

Utils.java:

package MyKmeans;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

public class Utils {

    // read the centers, either from a single HDFS file or from every file in a directory
    public static ArrayList<ArrayList<Double>> getCentersFromHDFS(String centersPath, boolean isDirectory) throws IOException {

        ArrayList<ArrayList<Double>> result = new ArrayList<ArrayList<Double>>();

        Path path = new Path(centersPath);

        Configuration conf = new Configuration();

        FileSystem fileSystem = path.getFileSystem(conf);

        if (isDirectory) {
            FileStatus[] listFile = fileSystem.listStatus(path);
            for (int i = 0; i < listFile.length; i++) {
                // skip _SUCCESS and other bookkeeping entries in a job output directory
                if (listFile[i].getPath().getName().startsWith("_")) {
                    continue;
                }
                result.addAll(getCentersFromHDFS(listFile[i].getPath().toString(), false));
            }
            return result;
        }

        FSDataInputStream fsis = fileSystem.open(path);
        LineReader lineReader = new LineReader(fsis, conf);

        Text line = new Text();

        // one center per line
        while (lineReader.readLine(line) > 0) {
            ArrayList<Double> tempList = textToArray(line);
            result.add(tempList);
        }
        lineReader.close();
        return result;
    }

    // delete an HDFS file or directory recursively
    public static void deletePath(String pathStr) throws IOException {
        Configuration conf = new Configuration();
        Path path = new Path(pathStr);
        FileSystem hdfs = path.getFileSystem(conf);
        hdfs.delete(path, true);
    }

    // parse one comma-separated line into a list of doubles
    public static ArrayList<Double> textToArray(Text text) {
        ArrayList<Double> list = new ArrayList<Double>();
        String[] fields = text.toString().split(",");
        for (int i = 0; i < fields.length; i++) {
            // Double.parseDouble ignores leading/trailing whitespace,
            // so fields coming back from the reduce output parse cleanly
            list.add(Double.parseDouble(fields[i]));
        }
        return list;
    }

    // compare the new centers produced by reduce with the old ones;
    // returns true once they are identical, i.e. the algorithm has converged
    public static boolean compareCenters(String centerPath, String newPath) throws IOException {

        List<ArrayList<Double>> oldCenters = Utils.getCentersFromHDFS(centerPath, false);
        List<ArrayList<Double>> newCenters = Utils.getCentersFromHDFS(newPath, true);

        int size = oldCenters.size();
        int fieldSize = oldCenters.get(0).size();
        double distance = 0;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < fieldSize; j++) {
                double t1 = Math.abs(oldCenters.get(i).get(j));
                double t2 = Math.abs(newCenters.get(i).get(j));
                distance += Math.pow((t1 - t2) / (t1 + t2), 2);
            }
        }

        if (distance == 0.0) {
            // converged: delete the reduce output so the final labeling pass can write there
            Utils.deletePath(newPath);
            return true;
        } else {
            // not converged: overwrite the centers file with the new centers,
            // then delete the reduce output so the next iteration can run

            Configuration conf = new Configuration();
            Path outPath = new Path(centerPath);
            FileSystem fileSystem = outPath.getFileSystem(conf);

            // create(outPath, true) truncates the old centers file; keep one
            // output stream open while appending every part file to it
            FSDataOutputStream out = fileSystem.create(outPath, true);

            Path inPath = new Path(newPath);
            FileStatus[] listFiles = fileSystem.listStatus(inPath);
            for (int i = 0; i < listFiles.length; i++) {
                // copy only the reducer's part files, not _SUCCESS and friends
                if (listFiles[i].getPath().getName().startsWith("_")) {
                    continue;
                }
                FSDataInputStream in = fileSystem.open(listFiles[i].getPath());
                // 'false' keeps the shared output stream open between files
                IOUtils.copyBytes(in, out, 4096, false);
                in.close();
            }
            out.close();

            // delete the reduce output directory so the next job run can write there
            Utils.deletePath(newPath);
        }

        return false;
    }
}
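
One caveat about the metric: map() and compareCenters() both use a normalized squared difference, summing ((|a| - |b|) / (|a| + |b|))^2 over the coordinates, rather than the plain Euclidean distance of textbook K-means (and it produces NaN when a coordinate is zero in both vectors). If standard behavior is wanted, a squared-Euclidean helper along these lines could be used instead — a sketch only, not part of the original code:

    // Hypothetical alternative to the normalized metric above: plain squared
    // Euclidean distance. Comparing squared distances selects the same
    // nearest center as comparing true distances, so the sqrt can be skipped.
    public static double squaredEuclidean(ArrayList<Double> a, ArrayList<Double> b) {
        double sum = 0;
        for (int i = 0; i < a.size(); i++) {
            double d = a.get(i) - b.get(i);
            sum += d * d;
        }
        return sum;
    }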

Dataset: http://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data

The results can be compared with those at http://blog.csdn.net/jshayzf/article/details/22739063 (provided the initial centers are the same).

 
