hadoop mapreduce v1接口实现自定义inputformat,mysql作为输入

工作需要,自定义实现了hadoop的一个inputformat,使用的是v1接口(org.apache.hadoop.mapred)。此inputformat的功能是读取mysql数据库中的数据,并将这些数据分成若干块,每一块作为一个InputSplit。代码如下:

package com.demo7;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.log4j.Logger;

import com.google.common.base.Joiner;

/**
 * {@link InputFormat} (old {@code org.apache.hadoop.mapred} API) whose input is a
 * MySQL table instead of files.
 *
 * <p>{@link #getSplits} queries the (trade day, stock code) pairs in the configured
 * date range, groups the stock codes by trade day and cuts every group into chunks of
 * at most {@code oneTaskStocks} codes. Each chunk becomes one {@link MysqlInputSplit}.
 * The {@link MysqlRecordReader} created for a split emits exactly one record:
 * key = trade day ({@code yyyy-MM-dd}), value = the chunk's stock codes joined by ':'.
 */
public class MysqlInputformat implements InputFormat<Text,Text>{
    private static final Logger logger = Logger.getLogger(MysqlInputformat.class);

    // Inclusive bounds of the trade-day range, read from the project configuration.
    private String beginTradeDay = Config.getConfig().getProperty("dealday.begin.mysqlinputformat");
    private String endTradeDay = Config.getConfig().getProperty("dealday.end.mysqlinputformat");
    // Maximum number of stock codes packed into a single split (i.e. a single map task).
    private int oneTaskStocks = Integer.valueOf(Config.getConfig().getProperty("stocknumber_permap.mysqlinputformat"));

    /**
     * Creates the reader for one split.
     *
     * @param split    must be a {@link MysqlInputSplit}
     * @param conf     unused
     * @param reporter unused
     * @return a reader that yields the split's single (trade day, stock codes) record
     * @throws ClassCastException if {@code split} is not a {@link MysqlInputSplit}
     */
    @Override
    public RecordReader<Text, Text> getRecordReader(InputSplit split,
            JobConf conf, Reporter reporter) throws IOException {
        MysqlInputSplit mysqlSplit = (MysqlInputSplit)split;
        logger.info("---------------------ln------------------------");
        logger.info(mysqlSplit.getTradeDay());
        logger.info(mysqlSplit.getStockcodes());
        return new MysqlRecordReader(mysqlSplit.getTradeDay(), mysqlSplit.getStockcodes());
    }

    /**
     * Builds the splits: one per (trade day, chunk of at most {@code oneTaskStocks} codes).
     *
     * @param arg0 unused
     * @param arg1 unused (the number of splits is determined by the data, not by this hint)
     * @return the generated splits; empty when the query returns no rows
     * @throws IOException if the chunk size is not positive, the database cannot be
     *                     read, or a trade day is not in {@code yyyyMMdd} form
     */
    @Override
    public InputSplit[] getSplits(JobConf arg0, int arg1) throws IOException {
        // A non-positive chunk size would make the chunking loop below spin forever.
        if (oneTaskStocks <= 0) {
            throw new IOException(
                    "stocknumber_permap.mysqlinputformat must be positive, got " + oneTaskStocks);
        }

        logger.info(String.format("begin generate map task, from %s to %s", beginTradeDay, endTradeDay));
        HashMap<String,ArrayList<String>> dayStocks = loadDayStocks();

        Joiner joiner = Joiner.on(":").useForNull("");
        SimpleDateFormat dbDayFormat = new SimpleDateFormat("yyyyMMdd");
        SimpleDateFormat splitDayFormat = new SimpleDateFormat("yyyy-MM-dd");

        ArrayList<InputSplit> splits = new ArrayList<InputSplit>();
        for(Map.Entry<String, ArrayList<String>> dayStockEntry : dayStocks.entrySet()){
            // Convert the day once per trade day rather than once per chunk.
            String tradeDay;
            try {
                tradeDay = splitDayFormat.format(dbDayFormat.parse(dayStockEntry.getKey()));
            } catch (ParseException e) {
                // Fail loudly: silently dropping the day would lose its data without a trace.
                throw new IOException(
                        "trade day '" + dayStockEntry.getKey() + "' is not in yyyyMMdd format", e);
            }

            ArrayList<String> stocks = dayStockEntry.getValue();
            int from = 0;
            while (from < stocks.size()) {
                int to = Math.min(from + oneTaskStocks, stocks.size());

                MysqlInputSplit split = new MysqlInputSplit();
                split.setTradeDay(tradeDay);
                split.setStockcodes(joiner.join(stocks.subList(from, to)));
                splits.add(split);

                from = to;
            }
        }

        return splits.toArray(new InputSplit[splits.size()]);
    }

    /**
     * Reads all (date, stock_code) rows in the configured range and groups them by date.
     *
     * @return map whose key is the trade day as stored in the database and whose value
     *         is the list of stock codes for that day
     * @throws IOException if connecting or querying fails; the original exception is the cause
     */
    private HashMap<String,ArrayList<String>> loadDayStocks() throws IOException {
        HashMap<String,ArrayList<String>> dayStocks = new HashMap<String,ArrayList<String>>();
        // NOTE(review): the bounds are formatted straight into the SQL text; they come
        // from configuration here, but a PreparedStatement would be safer.
        String sql = String.format(
                "select date,stock_code from gushitong.s_kline_day_complexbeforeright_ln " +
                " where date>='%s' and date<='%s'",
                beginTradeDay, endTradeDay);

        Connection conn = null;
        Statement stmt = null;
        ResultSet rs = null;
        try {
            conn = DriverManager.getConnection(MysqlProxy.getProxoolUrl(), MysqlProxy.getProxoolConf());
            stmt = conn.createStatement();
            rs = stmt.executeQuery(sql);
            while (rs.next()) {
                String date = rs.getString("date");
                String stockcode = rs.getString("stock_code");
                ArrayList<String> stocks = dayStocks.get(date);
                if (stocks == null) {
                    stocks = new ArrayList<String>(3300);
                    dayStocks.put(date, stocks);
                }
                stocks.add(stockcode);
            }
        } catch (Exception e) {
            // Propagate instead of continuing with partial/empty data, which would make
            // the job "succeed" without doing any work.
            logger.error("failed to load trade days and stock codes from mysql", e);
            throw new IOException("failed to load trade days and stock codes from mysql", e);
        } finally {
            closeQuietly(rs, stmt, conn);
        }
        return dayStocks;
    }

    /** Closes the JDBC resources in reverse order of creation; close failures are only logged. */
    private static void closeQuietly(ResultSet rs, Statement stmt, Connection conn) {
        if (rs != null) {
            try {
                rs.close();
            } catch (Exception e) {
                logger.warn("failed to close ResultSet", e);
            }
        }
        if (stmt != null) {
            try {
                stmt.close();
            } catch (Exception e) {
                logger.warn("failed to close Statement", e);
            }
        }
        if (conn != null) {
            try {
                conn.close();
            } catch (Exception e) {
                logger.warn("failed to close Connection", e);
            }
        }
    }

    /**
     * One unit of work: a trade day plus a ':'-joined chunk of stock codes.
     * Serialized by the framework through {@link #write} / {@link #readFields}.
     */
    public static class MysqlInputSplit implements InputSplit{

        private String tradeDay = null;
        private String stockcodes = null;

        public String getTradeDay() {
            return tradeDay;
        }

        public void setTradeDay(String tradeDay) {
            this.tradeDay = tradeDay;
        }

        public String getStockcodes() {
            return stockcodes;
        }

        public void setStockcodes(String stockcodes) {
            this.stockcodes = stockcodes;
        }

        /** Restores the fields in the same order {@link #write} emitted them. */
        @Override
        public void readFields(DataInput in) throws IOException {
            this.tradeDay = Text.readString(in);
            this.stockcodes = Text.readString(in);
        }

        /** Writes trade day first, then stock codes. */
        @Override
        public void write(DataOutput out) throws IOException {
            Text.writeString(out, tradeDay);
            Text.writeString(out, stockcodes);
        }

        /** Always 0: the split is not backed by bytes of a file. */
        @Override
        public long getLength() throws IOException {
            return 0;
        }

        /**
         * Returns a placeholder location. Per the original author's note, a value must
         * be returned because the framework uses it whether or not it is meaningful.
         */
        @Override
        public String[] getLocations() throws IOException {
            String[] arr = {"aa"};
            return arr;
        }

    }

    /**
     * Reader that emits exactly one record for its split:
     * key = trade day, value = ':'-joined stock codes.
     */
    public static class MysqlRecordReader implements RecordReader<Text, Text>{

        /** Duration treated as 100% progress: 15 hours, in milliseconds. */
        private static final float EXPECTED_DURATION_MS = (float)(15*3600*1000);

        public String tradeDay = null;
        public String stockcodes = null;
        // True once the single record has been handed out by next().
        private boolean isDeal = false;
        // Creation time of this reader; getProgress() measures elapsed time from here.
        private long begintimeLong = new Date().getTime();

        public MysqlRecordReader(String tradeDay, String stockcodes){
            this.tradeDay = tradeDay;
            this.stockcodes = stockcodes;
        }

        /** Nothing to release: the reader holds no external resources. */
        @Override
        public void close() throws IOException {
        }

        @Override
        public Text createKey() {
            return new Text();
        }

        @Override
        public Text createValue() {
            return new Text();
        }

        /** Always 0: there is no byte position in this input. */
        @Override
        public long getPos() throws IOException {
            return 0;
        }

        /**
         * Time-based progress estimate: elapsed time since construction divided by an
         * expected 15 hours, capped at 0.9.
         */
        @Override
        public float getProgress() throws IOException {
            float progress = ((float)(new Date().getTime() - this.begintimeLong))/EXPECTED_DURATION_MS;
            logger.info(String.format("get process is %f", progress));
            return Math.min(0.9f, progress);
        }

        /**
         * Fills in the single record on the first call and returns {@code true};
         * every later call returns {@code false}.
         */
        @Override
        public synchronized boolean next(Text key, Text value) throws IOException {
            if (this.isDeal) {
                return false;
            }
            this.isDeal = true;

            key.set(this.tradeDay);
            value.set(this.stockcodes);

            return true;
        }

    }
}

你可能感兴趣的:(mapreduce,hadoop,inputformat)