Spark 装载 MySQL的数据

使用Spark 加载 MySQL的数据来进行分析处理。。

多余的就不解释了,,,直接捞干的。


InputFormat实现类代码如下:

package com.rayn.spark.plugin.mysql.impl;

import com.rayn.spark.plugin.mysql.MysqlMapReduce;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * @author Rayn on 2015/12/2 9:37
 * @email [email protected]
 */
public class MysqlInputFormat<T> extends InputFormat<Text, T> implements MysqlMapReduce
{

    /**
     * Logically split the set of input files for the job.
     * <p/>
     * <p>Each {@link InputSplit} is then assigned to an individual {@link Mapper}
     * for processing.</p>
     * <p/>
     * <p><i>Note</i>: The split is a <i>logical</i> split of the inputs and the
     * input files are not physically split into chunks. For e.g. a split could
     * be <i>&lt;input-file-path, start, offset&gt;</i> tuple. The InputFormat
     * also creates the {@link RecordReader} to read the {@link InputSplit}.
     *
     * @param context job configuration.
     * @return an array of {@link InputSplit}s for the job.
     */
    @Override
    public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException
    {
        List<InputSplit> inputSplits = new ArrayList<InputSplit>();

        Configuration conf = context.getConfiguration();

        String driver = conf.get(DRIVER);
        String url = conf.get(URL);
        String username = conf.get(USERNAME);
        String password = conf.get(PASSWORD);
        String sql = conf.get(SQL);
        int type = Integer.parseInt(conf.get(DATABASE_TYPE));

        MysqlInputSplit inputSplit = new MysqlInputSplit(driver, url, username, password, sql, type);

        inputSplits.add(inputSplit);
        return inputSplits;
    }

    /**
     * Create a record reader for a given split. The framework will call
     * {@link RecordReader#initialize(InputSplit, TaskAttemptContext)} before
     * the split is used.
     *
     * @param split   the split to be read
     * @param context the information about the task
     * @return a new record reader
     * @throws IOException
     * @throws InterruptedException
     */
    @Override
    public RecordReader<Text, T> createRecordReader(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException
    {
        Configuration conf = context.getConfiguration();


        RecordReader reader = new MysqlStringRecordReader();

        return reader;
    }
}

然后是MySQLStringRecordReader的实现类如下:

package com.rayn.spark.plugin.mysql.impl;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

import java.io.IOException;
import java.net.InetAddress;
import java.sql.*;
import java.util.ArrayDeque;
import java.util.Queue;

/**
 *
 *
 *
 * @author Rayn on 2015/11/25 9:47
 * @email [email protected]
 */
public class MysqlStringRecordReader extends RecordReader<Text, Text> {


    private static final Log LOG = LogFactory.getLog(MysqlStringRecordReader.class);

    private Connection conn = null;
    private Statement stmt = null;
    private ResultSet resultSet = null;

    private Queue<Object[]> queue = null;

    private Text key;
    protected Text value;

    private long totalKVs = 0;
    private long processedKVs = 0;

    private int cursor = 0;

    public MysqlStringRecordReader() {
    }


    /**
     *
     * @param genericSplit
     * @param context
     * @throws IOException
     * @throws InterruptedException
     */
    @Override
    public void initialize(InputSplit genericSplit, TaskAttemptContext context) throws IOException, InterruptedException {
        MysqlInputSplit split = (MysqlInputSplit) genericSplit;

        LOG.info("MySQLArrayRecordReader  ::  initialize");
        LOG.info("mysql address = [" + split.getLocations() + "]");
        LOG.info("mysql IP = [" + InetAddress.getLocalHost().getHostAddress() + "]");

        queue = new ArrayDeque<Object[]>();
        try
        {
            Class.forName(split.getDriver());
            conn = DriverManager.getConnection(split.getUrl(), split.getUsername(), split.getPassowrd());
            stmt = conn.createStatement();
            resultSet = stmt.executeQuery(split.getSql());

            resultSet.last();
            totalKVs = resultSet.getRow();

            resultSet.first();

        }
        catch (Exception e)
        {
            LOG.error("创建SQL链接异常。");
            e.printStackTrace();
            try
            {
                if(null != resultSet){
                    resultSet.close();
                }
                if(null != stmt){
                    stmt.close();
                }
                if(null != conn){
                    conn.close();
                }
            }
            catch (SQLException e2)
            {
                e2.printStackTrace();
            }
        }
    }

    @Override
    public boolean nextKeyValue()
    {
        this.key = null;
        this.value = null;
        try
        {
            if(resultSet.next()){
                processedKVs = resultSet.getRow();
                this.key = new Text(String.valueOf(resultSet.getInt(1)));
                this.value = new Text(resultSet.getString(2));
                return true;
            }

            while (cursor > 0) {
                while (resultSet.next()) {
                    processedKVs = resultSet.getRow();
                    this.key = new Text(String.valueOf(resultSet.getInt(1)));
                    this.value = new Text(resultSet.getString(2));
                    return true;
                }
            }
        }
        catch (SQLException e)
        {
            e.printStackTrace();
        }
        return false;
    }


    @Override
    public Text getCurrentKey() throws IOException, InterruptedException {
        return key;
    }


    @Override
    public Text getCurrentValue() throws IOException, InterruptedException {
        return value;
    }


    @Override
    public float getProgress() throws IOException, InterruptedException {
        return totalKVs == 0 ? 1.0F : (processedKVs * 1.0F / totalKVs);
    }

    @Override
    public synchronized void close() throws IOException {
        try
        {
            if(null != resultSet){
                resultSet.close();
            }
            if(null != stmt){
                stmt.close();
            }
            if(null != conn){
                conn.close();
            }
        }
        catch (SQLException e)
        {
            e.printStackTrace();
        }
    }
 }

最后再是 MySQLInputSplit的实现类

package com.rayn.spark.plugin.mysql.impl;

import com.rayn.spark.plugin.mysql.MysqlMapReduce;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputSplit;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.Serializable;

/**
 *
 * @author Rayn on 2015/12/2 9:39
 * @email [email protected]
 */
public class MysqlInputSplit extends InputSplit implements Writable, Serializable
{

    private final static long serialVersionUID = 0L;

    private String driver = "com.mysql.jdbc.Driver";
    private String host = "localhost";
    private String port = "3306";
    private String database = "test";
    private String url = "jdbc:mysql://" + host + ":" + port + "/" + database + "?autoReconnect=true&autoReconnectForPools=true&useUnicode=true&generateSimpleParameterMetadata=true&characterEncoding=utf8";
    private String username = "root";
    private String passowrd = "";
    private String tablename = "test";
    private String sql = "";

    public MysqlInputSplit()
    {
    }

    public MysqlInputSplit(String driver, String url, String username, String passowrd, String sql, int type)
    {
        this.driver = driver;
        this.url = url;
        this.username = username;
        this.passowrd = passowrd;
        this.sql = sql;


        if(type == MysqlMapReduce.DatabaseType.MYSQL.getType()){
            String linkStr = url.substring(url.indexOf("//") + 2, url.indexOf("?"));
            String[] split = StringUtils.split(port, ":");
            this.host = split[0];
            this.port = split[1].split("/")[0];
            this.database = split[1].split("/")[1];
        }

    }

    public MysqlInputSplit(String driver, String host, String port, String url, String username, String passowrd, String tablename, String sql)
    {
        this.driver = driver;
        this.host = host;
        this.port = port;
        this.url = url;
        this.username = username;
        this.passowrd = passowrd;
        this.tablename = tablename;
        this.sql = sql;
    }

    public String getDriver()
    {
        return driver;
    }

    public void setDriver(String driver)
    {
        this.driver = driver;
    }

    public String getHost()
    {
        return host;
    }

    public void setHost(String host)
    {
        this.host = host;
    }

    public String getPort()
    {
        return port;
    }

    public void setPort(String port)
    {
        this.port = port;
    }

    public String getUrl()
    {
        return url;
    }

    public void setUrl(String url)
    {
        this.url = url;
    }

    public String getUsername()
    {
        return username;
    }

    public void setUsername(String username)
    {
        this.username = username;
    }

    public String getPassowrd()
    {
        return passowrd;
    }

    public void setPassowrd(String passowrd)
    {
        this.passowrd = passowrd;
    }

    public String getTablename()
    {
        return tablename;
    }

    public void setTablename(String tablename)
    {
        this.tablename = tablename;
    }

    public String getSql()
    {
        return sql;
    }

    public void setSql(String sql)
    {
        this.sql = sql;
    }

    /**
     * Get the size of the split, so that the input splits can be sorted by size.
     *
     * @return the number of bytes in the split
     * @throws IOException
     * @throws InterruptedException
     */
    @Override
    public long getLength() throws IOException, InterruptedException
    {
        return 1;
    }

    /**
     * Get the list of nodes by name where the data for the split would be local.
     * The locations do not need to be serialized.
     *
     * @return a new array of the node nodes.
     * @throws IOException
     * @throws InterruptedException
     */
    @Override
    public String[] getLocations() throws IOException, InterruptedException
    {
        return new String[]{host};
    }

    @Override
    public void write(DataOutput dataOutput) throws IOException
    {
        dataOutput.writeUTF(driver);
        dataOutput.writeUTF(host);
        dataOutput.writeUTF(port);
        dataOutput.writeUTF(url);
        dataOutput.writeUTF(username);
        dataOutput.writeUTF(passowrd);
        dataOutput.writeUTF(tablename);
        dataOutput.writeUTF(sql);
    }

    @Override
    public void readFields(DataInput dataInput) throws IOException
    {
        dataInput.readUTF();
        dataInput.readUTF();
        dataInput.readUTF();
        dataInput.readUTF();
        dataInput.readUTF();
        dataInput.readUTF();
        dataInput.readUTF();
        System.out.println(dataInput);
    }
}


接下来就是使用 SparkContext接口来读取数据库进行数据操作分析。给出示例程序如下:

JavaPairRDD<Text, Text> rdd = sc.newAPIHadoopRDD(configuration, MysqlInputFormat.class, Text.class, Text.class);


直接可以运行,会读取到数据库中的数据信息。。。运行结果如下:

Spark 装载 MySQL的数据

你可能感兴趣的:(spark)