Spark: using a connection pool to fix "Task not serializable" and improve write performance

Note: the complete code is included at the end.
1. The error:

Exception in thread "main" org.apache.spark.SparkException: Task not serializable

2. The problem code:

SparkSession sparkSession = SparkSession.builder().appName("LogsHandler").master("local[*]").getOrCreate();
Dataset<Row> dataset = sparkSession.read().parquet("file:///Users/admin/fsdownload/logs/part-00000-14718313-4d0c-43d7-aecb-6902bf3a0cbc-c000.snappy.parquet");
dataset.createOrReplaceTempView("traffic_statistic");
Dataset<Row> result = dataset.sqlContext().sql("select ip,sum(flow) as total_flow from traffic_statistic where flow is not null and ip is not null group by ip");
result = result.na().drop("all", new String[]{"total_flow"}); // drop rows whose total_flow is null (see https://blog.csdn.net/qq_39570355/article/details/117188897)
try {
    Class.forName("com.mysql.cj.jdbc.Driver");
    // The connection is created on the driver ...
    Connection connection = DriverManager.getConnection("jdbc:mysql://124.222.10.220:3306/spark_test?useSSL=false", "root", "Wp@123456");

    result.foreachPartition(item -> {
        try {
            // ... but used inside the foreachPartition lambda, so Spark tries to serialize it and fails
            connection.setAutoCommit(false);
            PreparedStatement preparedStatement = connection.prepareStatement("insert into traffic_statistic(ip,flow) values (?,?)");
            while (item.hasNext()) {
                Row next = item.next();
                String ip = next.getString(0);
                Double totalFlow = next.getDouble(1);
                preparedStatement.setString(1, ip);
                preparedStatement.setDouble(2, totalFlow == null ? 0 : totalFlow);
                preparedStatement.addBatch();
            }
            preparedStatement.executeBatch();
            connection.commit();
        } catch (Exception e) {
            e.printStackTrace();
        }
    });
    connection.close();
} catch (SQLException | ClassNotFoundException e) {
    e.printStackTrace();
}

3. Root cause: the lambda passed to foreachPartition captures the JDBC Connection that was created on the driver, and java.sql.Connection is not serializable, so Spark fails when it tries to ship the closure to the executors. The official Spark Streaming programming guide ("Design Patterns for using foreachRDD") describes exactly this pitfall:

dstream.foreachRDD { rdd =>
  val connection = createNewConnection()  // executed at the driver
  rdd.foreach { record =>
    connection.send(record) // executed at the worker
  }
}

This is incorrect, because it requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferable across machines. The error may show up as a serialization error (the connection object is not serializable), an initialization error (the connection object needs to be initialized at the workers), and so on. The correct solution is to create the connection object at the worker.
However, this can lead to another common mistake: creating a new connection for every record. For example,

dstream.foreachRDD { rdd =>
  rdd.foreach { record =>
    val connection = createNewConnection()
    connection.send(record)
    connection.close()
  }
}

Typically, creating a connection object has time and resource overheads. Creating and destroying a connection object for every record therefore incurs unnecessarily high overhead and can significantly reduce the overall throughput of the system. A better solution is to use rdd.foreachPartition: create a single connection object and use it to send all the records in an RDD partition.

dstream.foreachRDD { rdd =>
  rdd.foreachPartition { partitionOfRecords =>
    val connection = createNewConnection()
    partitionOfRecords.foreach(record => connection.send(record))
    connection.close()
  }
}

This amortizes the connection creation overhead over many records.
Finally, this can be optimized further by reusing connection objects across multiple RDDs/batches. One can maintain a static pool of connection objects that is reused as the RDDs of multiple batches are pushed to the external system, reducing the overhead even further.

dstream.foreachRDD { rdd =>
  rdd.foreachPartition { partitionOfRecords =>
    // ConnectionPool is a static, lazily initialized pool of connections
    val connection = ConnectionPool.getConnection()
    partitionOfRecords.foreach(record => connection.send(record))
    ConnectionPool.returnConnection(connection)  // return to the pool for future reuse
  }
}

Note that the connections in the pool should be created lazily, on demand, and timed out if not used for a while. This achieves the most efficient sending of data to external systems. (The StackKeyedObjectPool used in section 4 below does not evict idle connections; a variant that does is sketched right after MysqlUtils.)
Other points to remember:
DStreams are executed lazily by the output operations, just like RDDs are lazily executed by RDD actions. Specifically, the RDD actions inside DStream output operations force the processing of the received data. So if your application does not have any output operation, or has output operations such as dstream.foreachRDD() without any RDD action inside them, then nothing gets executed; the system simply receives the data and discards it.
By default, output operations are executed one at a time, in the order they are defined in the application.
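
Back to this article's code: mapping the foreachPartition pattern above onto the Dataset from section 2, without a pool yet, would look roughly like the sketch below. It reuses the URL, credentials and table from section 2; the only change is that the connection is created inside the partition function, so it lives entirely on the executor and nothing needs to be serialized from the driver:

result.foreachPartition(item -> {
    // Created on the executor, so no connection object has to be serialized from the driver
    try (Connection connection = DriverManager.getConnection(
            "jdbc:mysql://124.222.10.220:3306/spark_test?useSSL=false", "root", "Wp@123456")) {
        connection.setAutoCommit(false);
        PreparedStatement preparedStatement =
                connection.prepareStatement("insert into traffic_statistic(ip,flow) values (?,?)");
        while (item.hasNext()) {
            Row next = item.next();
            preparedStatement.setString(1, next.getString(0));
            preparedStatement.setDouble(2, next.getDouble(1));
            preparedStatement.addBatch();
        }
        preparedStatement.executeBatch();
        connection.commit();
    } catch (Exception e) {
        e.printStackTrace();
    }
});

This already fixes the "Task not serializable" error, but still opens one connection per partition per job; the pool in section 4 goes one step further and reuses connections across partitions and jobs on the same executor.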

4. Finally, the connection pool code

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

import org.apache.commons.pool.KeyedPoolableObjectFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// Tells commons-pool how to create, validate and destroy a JDBC connection for a given key
public class ConnectionPoolableFactory implements KeyedPoolableObjectFactory<ConnectionPoolKey, Connection> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionPoolableFactory.class);
    @Override
    public Connection makeObject(ConnectionPoolKey key) throws Exception {
        if (LOGGER.isInfoEnabled()) {
            LOGGER.info("Connection create path is " + key.getPath());
        }

        Class.forName("com.mysql.cj.jdbc.Driver");
        try {
            Connection connection = DriverManager.getConnection(key.getPath(), key.getUsername(), key.getPassword());
            connection.setAutoCommit(false);
            return connection;
        } catch (SQLException e) {
            throw new RuntimeException("Connection create fail", e);
        }
    }
    @Override
    public void destroyObject(ConnectionPoolKey key, Connection connection) throws Exception {
        if (connection != null){
            LOGGER.info("will close Connection" + key.getPath());
            connection.close();
        }
    }
    @Override
    public boolean validateObject(ConnectionPoolKey key, Connection connection) {
        try {
            if (connection != null && !connection.isClosed()) {
                return true;
            }
        } catch (SQLException e) {
            // a connection we cannot even inspect is treated as invalid
        }
        if (LOGGER.isInfoEnabled()) {
            LOGGER.info("connection valid false, connection: " + key.getPath());
        }
        return false;
    }
    @Override
    public void activateObject(ConnectionPoolKey key, Connection connection) throws Exception {}
    @Override
    public void passivateObject(ConnectionPoolKey key, Connection connection) throws Exception {}
}
// Pool key: identifies the target database (JDBC URL, username, password)
public class ConnectionPoolKey {
    private Integer connectionId = 0;
    private String path;
    private String username;
    private String password;

    public ConnectionPoolKey(String path, String username, String password) {
        this.path = path;
        this.username = username;
        this.password = password;
    }

    public Integer getConnectionId() {
        return connectionId;
    }

    public void setConnectionId(Integer connectionId) {
        this.connectionId = connectionId;
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public String getPassword() {
        return password;
    }

    public void setPassword(String password) {
        this.password = password;
    }
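
    // Optional: value-based equality for the pool key, sketched here as an assumption-free safety net.
    // It is not strictly needed in this article because MysqlUtils always reuses a single static key
    // instance, but the pool looks connections up by this key, so if several key objects are ever
    // created for the same database it is safer to compare by value (connectionId is ignored on purpose).
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof ConnectionPoolKey)) return false;
        ConnectionPoolKey that = (ConnectionPoolKey) o;
        return java.util.Objects.equals(path, that.path)
                && java.util.Objects.equals(username, that.username)
                && java.util.Objects.equals(password, that.password);
    }

    @Override
    public int hashCode() {
        return java.util.Objects.hash(path, username, password);
    }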
}

import java.sql.Connection;

import org.apache.commons.pool.KeyedObjectPool;
import org.apache.commons.pool.impl.StackKeyedObjectPool;

// One static, lazily initialized pool per JVM (i.e. per executor), as recommended by the Spark docs
public class MysqlUtils {
    private static KeyedObjectPool pool = null;
    private static ConnectionPoolKey key = null;
    static {
        try {
            ConnectionPoolableFactory factory = new ConnectionPoolableFactory();
            pool = new StackKeyedObjectPool(factory);
            key = new ConnectionPoolKey("jdbc:mysql://<DB IP>:<port>/spark_test?useSSL=false", "<username>", "<password>");
        } catch (Exception e) {
            e.printStackTrace();
            throw new RuntimeException("Failed to initialize the connection pool", e);
        }
    }

    public static Connection getConnection() {
        try {
            return (Connection) pool.borrowObject(key);
        } catch (Exception e) {
            // Fail fast instead of returning null, which would only resurface later as a NullPointerException
            throw new RuntimeException("Failed to borrow a connection from the pool", e);
        }
    }

    public static void release(Connection connection) {
        try {
            pool.returnObject(key,connection);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}
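
As quoted above, the Spark docs also want pooled connections to be created lazily and timed out when idle. StackKeyedObjectPool creates objects lazily but never evicts idle ones. If you want the idle timeout as well, one option (a sketch, assuming commons-pool's org.apache.commons.pool.impl.GenericKeyedObjectPool, which runs a background evictor thread) is to build the pool like this in MysqlUtils's static block; borrowing and returning through MysqlUtils stays exactly the same, only the pool construction changes:

    // In MysqlUtils: same factory and key, but a pool that evicts idle connections
    static {
        GenericKeyedObjectPool<ConnectionPoolKey, Connection> evictingPool =
                new GenericKeyedObjectPool<ConnectionPoolKey, Connection>(new ConnectionPoolableFactory());
        evictingPool.setMaxActive(8);                         // at most 8 connections per key
        evictingPool.setMaxIdle(4);                           // keep at most 4 idle connections per key
        evictingPool.setTestOnBorrow(true);                   // validate a connection before handing it out
        evictingPool.setTimeBetweenEvictionRunsMillis(60000); // evictor thread runs once a minute
        evictingPool.setMinEvictableIdleTimeMillis(300000);   // close connections idle for over 5 minutes
        pool = evictingPool;
        key = new ConnectionPoolKey("jdbc:mysql://<DB IP>:<port>/spark_test?useSSL=false", "<username>", "<password>");
    }

The exact limits (8 active, 4 idle, 1-minute evictor, 5-minute idle timeout) are illustrative assumptions; tune them to your workload.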

import java.sql.Connection;
import java.sql.PreparedStatement;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class LogsHandlerPool {
    public static void main(String[] args) {
        // Build the SparkSession and load the parquet log file
        SparkSession sparkSession = SparkSession.builder().appName("LogsHandler").master("local[*]").getOrCreate();
        Dataset<Row> dataset = sparkSession.read().parquet("file:///Users/admin/fsdownload/logs/part-00000-14718313-4d0c-43d7-aecb-6902bf3a0cbc-c000.snappy.parquet");
        dataset.createOrReplaceTempView("traffic_statistic");
        Dataset<Row> result = dataset.sqlContext().sql("select ip,sum(flow) as total_flow from traffic_statistic where flow is not null and ip is not null group by ip");
        result = result.na().drop("all", new String[]{"total_flow"}); // drop rows whose total_flow is null (see https://blog.csdn.net/qq_39570355/article/details/117188897)

        result.foreachPartition(item -> {
            // Borrow a connection from the executor-local pool (created lazily on first use)
            Connection connection = MysqlUtils.getConnection();
            try {
                PreparedStatement preparedStatement = connection.prepareStatement("insert into traffic_statistic(ip,flow) values (?,?)");
                while (item.hasNext()) {
                    Row next = item.next();
                    String ip = next.getString(0);
                    Double totalFlow = next.getDouble(1);
                    preparedStatement.setString(1, ip);
                    preparedStatement.setDouble(2, totalFlow == null?0:totalFlow);
                    preparedStatement.addBatch();
                }
                preparedStatement.executeBatch();
                connection.commit();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // Return the connection to the pool so later partitions and jobs can reuse it
                MysqlUtils.release(connection);
            }
        });

    }
}
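
One last, optional performance note, since the title also promises better throughput: by default MySQL Connector/J sends the statements in a batch one at a time; setting rewriteBatchedStatements=true on the JDBC URL lets the driver rewrite the batch into multi-row inserts, which is usually much faster for this kind of bulk load. Applied to the (redacted) URL used by MysqlUtils, that would look like:

    key = new ConnectionPoolKey(
            "jdbc:mysql://<DB IP>:<port>/spark_test?useSSL=false&rewriteBatchedStatements=true",
            "<username>", "<password>");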
