关于在mapreduce框架中的两个矩阵相乘(A*B)的算法实现,有如下两种思路。。
第一,因为我们在学校课堂内的矩阵相乘的基本算法就是A的行与B的列相乘 当然要满足A的列的维数与B的行维数相同,才能满足相乘的条件。所以有如下基本思路:
让每个map任务计算A的一行乘以B的一列,最后由reduce进行求和输出。这是最原始的实现方法:
假设A(m*n) B(n*s)
map的输入的格式如下<<x,y>,<Ax,By>> 0=<x<m,0=<y<s,0=<z<n
其中 <x,y>是key,x代表A的行号,y代表B的列号,<<Ax,By>>是value,Ax代表A的第x行第z列的元素,By代表B的第y列的第z行的一个元素,
A的一行与B的一列输入到一个maptask中,我们只需要对每个键值对中的value的两个值相乘即可,输出一个<<x,y>,Ax*By>
然后到洗牌阶段,将相同的可以输入到一个Reduce task中,然后reduce只需对相同key的value列表进行Ax*By进行求和即可。这个算法说起来比较简单,但是如何控制split中的内容是主要的问题。
首先需要重写InputSplit,InputFormat,Partion,来控制数据的流动,在数据结构方面需要定义一个实现的WritableComparable借口的类来保存两个整数(因为前面的key和value都出现两个整数),而且对象可以排序。
IntPair.class实现:
package com.zxx.matrix; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import org.apache.hadoop.io.WritableComparable; public class IntPair implements WritableComparable { private int right=0; private int left=0; public IntPair(){} public IntPair(int right,int left){ this.right=right; this.left=left; } public int getRight(){ return right; } public int getLeft(){ return left; } public void setRight(int right){ this.right=right; } public void setLeft(int left){ this.left=left; } public String toString(){ return left+","+right; } @Override public void readFields(DataInput arg0) throws IOException { // TODO Auto-generated method stub right=arg0.readInt(); left=arg0.readInt(); } @Override public void write(DataOutput arg0) throws IOException { // TODO Auto-generated method stub arg0.writeInt(right); arg0.writeInt(left); } @Override public int compareTo(Object arg0) { // TODO Auto-generated method stub IntPair o=(IntPair)arg0; if(this.right<o.getRight()) { return -1; }else if (this.right>o.getRight()) { return 1; }else if (this.left<o.getLeft()) { return -1; }else if (this.left>o.getLeft()) { return 1; } return 0; } }
InputSplit.class(样例)
在这个类中用一个ArrayWritable 来保存元素的位置信息以及具体的元素信息
public class matrixInputSplit extends InputSplit implements Writable { private IntPair[] t;//具体元素信息 private IntPair location;//key的值,元素位置信息 private ArrayWritable intPairArray; public matrixInputSplit() { } public matrixInputSplit(int row,matrix left,int col,matrix right) { //填充intPairArray intPairArray=new ArrayWritable(IntPair.class); t=new IntPair[4]; location=new IntPair(row,col); for(int j=0;j<3;j++) { IntPair intPair=new IntPair(); intPair.setLeft(left.m[row][j]); intPair.setRight(right.m[j][col]); t[j]=intPair; } t[3]=location; intPairArray.set(t); } @Override public long getLength() throws IOException, InterruptedException { return 0; } @Override public String[] getLocations() throws IOException, InterruptedException { return new String[]{}; //返回空 这样JobClient就不会从文件中读取split } @Override public void readFields(DataInput arg0) throws IOException { this.intPairArray=new ArrayWritable(IntPair.class); this.intPairArray.readFields(arg0); } @Override public void write(DataOutput arg0) throws IOException { /*arg0.writeInt(t.length); for(int i=0;i<t.length;i++) { t[i].write(arg0); }*/ intPairArray.write(arg0); } public IntPair getLocation() { t=new IntPair[4]; try { t=(IntPair[])intPairArray.toArray(); } catch (Exception e) { System.out.println("toArray excption"); } return t[3]; } public IntPair[] getIntPairs() { t=new IntPair[4]; try { t=(IntPair[])intPairArray.toArray(); } catch (Exception e) { System.out.println("toArray excption"); } IntPair[] intL=new IntPair[3]; for(int i=0;i<3;i++) { intL[i]=t[i]; } return intL; } }
Inputformat.class
这个类比较简单,只需要实现getSplit方法即可,不过需要用户自定义一个方法就是从getInputfile获得的路径来解析矩阵,输入到split中即可。
matrixMul.class
public class MatrixNew { public static class MatrixMapper extends Mapper<IntPair, IntPair, IntPair, IntWritable> { public void map(IntPair key, IntPair value, Context context) { int left=0 ; int right=0; System.out.println("map is do"); left = value.getLeft(); right = value.getRight(); IntWritable result = new IntWritable(left * right); // key不变, // value中的两个int相乘 try { context.write(key, result); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (InterruptedException e) { // TODO Auto-generated catch block e.printStackTrace(); } // 输出kv对 } } public static class MatrixReducer extends Reducer<IntPair, IntWritable, IntPair, IntWritable> { private IntWritable result = new IntWritable(); public void reduce(IntPair key, Iterable<IntWritable> values, Context context) { int sum = 0; for (IntWritable val : values) { int v = val.get(); sum += v; } result.set(sum); try { context.write(key, result); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (InterruptedException e) { // TODO Auto-generated catch block e.printStackTrace(); } } } public static class FirstPartitioner extends Partitioner<IntPair, IntWritable> { public int getPartition(IntPair key, IntWritable value, int numPartitions) { int abs = Math.abs(key.getLeft()) % numPartitions; // numPartitions是reduce线程的数量 return abs; } } public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException { Configuration conf=new Configuration(); new GenericOptionsParser(conf, args); FileSystem fs=FileSystem.get(conf); Job job = new Job(conf, "New Matrix Multiply Job "); job.setJarByClass(MatrixNew.class); job.setNumReduceTasks(1); job.setInputFormatClass(matrixInputFormat.class); job.setOutputFormatClass(TextOutputFormat.class); job.setMapperClass(MatrixMapper.class); job.setReducerClass(MatrixReducer.class); job.setPartitionerClass(FirstPartitioner.class); job.setMapOutputKeyClass(IntPair.class); job.setMapOutputValueClass(IntWritable.class); job.setOutputKeyClass(IntPair.class); job.setOutputValueClass(IntWritable.class); matrixInputFormat.setInputPath(args[0]); FileOutputFormat.setOutputPath(job,new Path(fs.makeQualified(new Path("/newMartixoutput")).toString())); boolean ok = job.waitForCompletion(true); if(ok){ //删除临时文件 } } }
以上代码只是简单测试下。。如有问题欢迎大家指正!这里先谢过!
第二个方法就是矩阵分块相乘,这个算法网上有大牛已经给出了源代码。。。