When writing Hadoop map/reduce jobs there are many situations that call for a custom WritableComparable with several fields. The following example will hopefully be a useful reference for developers.
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.WritableUtils;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class StatWritable implements WritableComparable<StatWritable> {

    private long timestamp;
    private int systemId;
    private String group;
    private String item;

    public String getGroup() {
        return group;
    }

    public void setGroup(final String group) {
        this.group = group;
    }

    public String getItem() {
        return item;
    }

    public void setItem(final String item) {
        this.item = item;
    }

    public int getSystemId() {
        return systemId;
    }

    public void setSystemId(final int systemId) {
        this.systemId = systemId;
    }

    public long getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(final long timestamp) {
        this.timestamp = timestamp;
    }

    @Override
    public int compareTo(final StatWritable o) {
        // Compare field by field; the first field that differs decides the order.
        int cmp = Long.compare(timestamp, o.getTimestamp());
        if (cmp != 0) {
            return cmp;
        }
        cmp = Integer.compare(systemId, o.getSystemId());
        if (cmp != 0) {
            return cmp;
        }
        cmp = group.compareTo(o.getGroup());
        if (cmp != 0) {
            return cmp;
        }
        return item.compareTo(o.getItem());
    }

    /**
     * The fields must be written in exactly the same order in which
     * {@link StatWritable#readFields(java.io.DataInput)} reads them back, using the
     * {@link java.io.DataOutput} write method that matches each field's type. For a
     * variable-length field such as a String, write the byte length first, then the bytes.
     */
    @Override
    public void write(final DataOutput out) throws IOException {
        out.writeLong(timestamp);
        out.writeInt(systemId);
        // Use an explicit charset so the serialized form is identical on every node.
        final byte[] groupBytes = group.getBytes(StandardCharsets.UTF_8);
        WritableUtils.writeVInt(out, groupBytes.length);
        out.write(groupBytes, 0, groupBytes.length);
        final byte[] itemBytes = item.getBytes(StandardCharsets.UTF_8);
        WritableUtils.writeVInt(out, itemBytes.length);
        out.write(itemBytes, 0, itemBytes.length);
    }

    /**
     * The fields must be read in exactly the same order in which
     * {@link StatWritable#write(java.io.DataOutput)} wrote them, using the
     * {@link java.io.DataInput} read method that matches each field's type. For a
     * variable-length field, read the length first, then the bytes.
     */
    @Override
    public void readFields(final DataInput in) throws IOException {
        timestamp = in.readLong();
        systemId = in.readInt();
        final int groupLength = WritableUtils.readVInt(in);
        final byte[] groupBytes = new byte[groupLength];
        in.readFully(groupBytes, 0, groupLength);
        group = new String(groupBytes, StandardCharsets.UTF_8);
        final int itemLength = WritableUtils.readVInt(in);
        final byte[] itemBytes = new byte[itemLength];
        in.readFully(itemBytes, 0, itemLength);
        item = new String(itemBytes, StandardCharsets.UTF_8);
    }

    /**
     * Override toString so the key is written in a readable form to map or reduce output files.
     */
    @Override
    public String toString() {
        return systemId + " " + timestamp + " " + group + " " + item;
    }

    /**
     * Raw comparator that lets Hadoop compare serialized keys quickly, without deserializing them.
     * When overriding {@link com.unionpay.stat.hadoop.domain.StatWritable.Comparator#compare(byte[], int, int, byte[], int, int)},
     * the fields must be compared in the same order in which they are written and read in
     * {@link org.apache.hadoop.io.Writable#write(java.io.DataOutput)} and
     * {@link org.apache.hadoop.io.Writable#readFields(java.io.DataInput)}.
     */
    public static class Comparator extends WritableComparator {

        protected Comparator() {
            super(StatWritable.class);
        }

        @Override
        public int compare(final byte[] b1, final int s1, final int l1,
                           final byte[] b2, final int s2, final int l2) {
            try {
                // timestamp occupies the first 8 bytes.
                final long timestamp1 = readLong(b1, s1);
                final long timestamp2 = readLong(b2, s2);
                final int cmp1 = Long.compare(timestamp1, timestamp2);
                if (cmp1 != 0) {
                    return cmp1;
                }
                // systemId occupies the next 4 bytes.
                final int startIndex1_1 = s1 + 8;
                final int startIndex1_2 = s2 + 8;
                final int systemId1 = readInt(b1, startIndex1_1);
                final int systemId2 = readInt(b2, startIndex1_2);
                final int cmp2 = Integer.compare(systemId1, systemId2);
                if (cmp2 != 0) {
                    return cmp2;
                }
                // group: a vint length followed by that many bytes.
                final int startIndex2_1 = startIndex1_1 + 4;
                final int startIndex2_2 = startIndex1_2 + 4;
                final int groupLength1 = WritableUtils.decodeVIntSize(b1[startIndex2_1]) + readVInt(b1, startIndex2_1);
                final int groupLength2 = WritableUtils.decodeVIntSize(b2[startIndex2_2]) + readVInt(b2, startIndex2_2);
                final int cmp3 = compareBytes(b1, startIndex2_1, groupLength1, b2, startIndex2_2, groupLength2);
                if (cmp3 != 0) {
                    return cmp3;
                }
                // item: a vint length followed by that many bytes.
                final int startIndex3_1 = startIndex2_1 + groupLength1;
                final int startIndex3_2 = startIndex2_2 + groupLength2;
                final int itemLength1 = WritableUtils.decodeVIntSize(b1[startIndex3_1]) + readVInt(b1, startIndex3_1);
                final int itemLength2 = WritableUtils.decodeVIntSize(b2[startIndex3_2]) + readVInt(b2, startIndex3_2);
                return compareBytes(b1, startIndex3_1, itemLength1, b2, startIndex3_2, itemLength2);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Register the raw comparator with Hadoop so it is used whenever StatWritable keys are compared.
     */
    static {
        WritableComparator.define(StatWritable.class, new Comparator());
    }
}
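Because write() and readFields() must mirror each other exactly, a quick round trip through Hadoop's in-memory buffers is a cheap way to verify the serialization. Below is a minimal sketch; the field values and the small main method are only for illustration and are not part of the original example.

import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;

public class StatWritableRoundTrip {

    public static void main(String[] args) throws Exception {
        final StatWritable original = new StatWritable();
        original.setTimestamp(1385452800000L);   // arbitrary example values
        original.setSystemId(42);
        original.setGroup("payments");
        original.setItem("success-count");

        // Serialize with write() ...
        final DataOutputBuffer out = new DataOutputBuffer();
        original.write(out);

        // ... and read the same bytes back with readFields().
        final DataInputBuffer in = new DataInputBuffer();
        in.reset(out.getData(), out.getLength());
        final StatWritable copy = new StatWritable();
        copy.readFields(in);

        // If the two methods agree on field order and types, the copy compares equal to the original.
        System.out.println("round trip ok: " + (original.compareTo(copy) == 0));
        System.out.println(copy);
    }
}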
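To use the key in a job it is enough to declare it as the map output key class; since the static block has already registered the raw Comparator, the shuffle sort then works directly on the serialized bytes. The following driver is a minimal sketch assuming the newer org.apache.hadoop.mapreduce API; the mapper/reducer classes and the LongWritable value type are placeholders, not part of the original example.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class StatJobDriver {

    public static void main(String[] args) throws Exception {
        final Job job = Job.getInstance(new Configuration(), "stat aggregation");
        job.setJarByClass(StatJobDriver.class);

        // Plug in your own mapper/reducer here; they must emit StatWritable keys, e.g.:
        // job.setMapperClass(StatMapper.class);
        // job.setReducerClass(StatReducer.class);

        // StatWritable as the map output key: the registered raw Comparator is used for sorting.
        job.setMapOutputKeyClass(StatWritable.class);
        job.setMapOutputValueClass(LongWritable.class);

        FileInputFormat.addInputPath(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}

If you prefer not to rely on the static registration, the same comparator can also be set explicitly on the job via job.setSortComparatorClass(StatWritable.Comparator.class).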