Some Java API operations on Dataset

Table of Contents

      • 1. Adding a new column to a DataFrame in Spark SQL with the Java API and JavaRDD
      • 2. Traversing a Dataset with foreachPartition
      • 3. Custom Partitioner for a Dataset
      • 4. Repartitioning a Dataset and getting the number of partitions
      • 5. Deduplication with dropDuplicates
      • 6. Converting a Dataset to a List
      • 7. User-defined function (UDF)
      • 8. Replace functions
      • 9. Using na.fill
      • 10. Using if

1. Adding a new column to a DataFrame in Spark SQL with the Java API and JavaRDD

  Create a new DataFrame after applying a mapPartitions function:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class Handler implements Serializable {

    public void handler(SparkSession spark, Dataset<Row> sourceData) {
        Dataset<Row> rowDataset = sourceData
                .where("rowKey = 'abcdefg_123'")
                .selectExpr("split(rowKey, '_')[0] as id",
                        "name",
                        "time")
                .where("name = '小强'")
                .orderBy(functions.col("id").asc(), functions.col("time").desc());

        FlatMapFunction<Iterator<Row>, Row> mapPartitionsToTime = rows ->
        {
            int count = 0; // only increments within a partition; it is not a global counter
            String startTime = "";
            String endTime = "";
            List<Row> mappedRows = new ArrayList<>();
            while(rows.hasNext())
            {
                count++;
                Row next = rows.next();
                String id = next.getAs("id");
                if (count == 2) {
					startTime = next.getAs("time");
					endTime = next.getAs("time");
                }
                Row mappedRow = RowFactory.create(next.getString(0), next.getString(1), next.getString(2), startTime, endTime);
                mappedRows.add(mappedRow);
            }
            return mappedRows.iterator();
        };

        JavaRDD<Row> sensorDataDoubleRDD = rowDataset.toJavaRDD().mapPartitions(mapPartitionsToTime);

        StructType oldSchema=rowDataset.schema();
        StructType newSchema =oldSchema.add("startTime",DataTypes.StringType,false)
                .add("endTime",DataTypes.StringType,false);

        System.out.println("The new schema is: ");
        newSchema.printTreeString();

        System.out.println("The old schema is: ");
        oldSchema.printTreeString();

        Dataset<Row> sensorDataDoubleDF=spark.createDataFrame(sensorDataDoubleRDD, newSchema);
        sensorDataDoubleDF.show(100, false);
    }
}

Output:

The new schema is: 
root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- time: string (nullable = true)
 |-- startTime: string (nullable = false)
 |-- endTime: string (nullable = false)

The old schema is: 
root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- time: string (nullable = true)

+-----------+---------+----------+----------+----------+
|id         |name     |time      |startTime |endTime   |
+-----------+---------+----------+----------+----------+
|abcdefg_123|xiaoqiang|1693462023|1693462023|1693462023|
|abcdefg_321|xiaoliu  |1693462028|1693462028|1693462028|
+-----------+---------+----------+----------+----------+

References:
java - Adding a new column to a DataFrame in Spark SQL using the Java API and JavaRDD
java.util.Arrays$ArrayList cannot be cast to java.util.Iterator

2. Traversing a Dataset with foreachPartition

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;

public class Handler implements Serializable {

    public void handler(Dataset<Row> sourceData) {
        JavaRDD<Row> dataRDD = sourceData.toJavaRDD();
        dataRDD.foreachPartition(new VoidFunction<Iterator<Row>>() {
            @Override
            public void call(Iterator<Row> rowIterator) throws Exception {
                while (rowIterator.hasNext()) {
                    Row next = rowIterator.next();
                    String id = next.getAs("id");
                    if (id.equals("123")) {
                        String startTime = next.getAs("time");
                        // other business logic goes here
                    }
                }
            }
        });

	    // The same traversal written as a lambda expression
	    dataRDD.foreachPartition((VoidFunction<Iterator<Row>>) rowIterator -> {
            while (rowIterator.hasNext()) {
                Row next = rowIterator.next();
                String id = next.getAs("id");
                if (id.equals("123")) {
                    String startTime = next.getAs("time");
                    // other business logic goes here
                }
            }
        });
    }
}

3. Custom Partitioner for a Dataset

Reference: custom partitioner for Spark (Java version)

import org.apache.commons.collections.CollectionUtils;
import org.apache.spark.Partitioner;
import org.junit.Assert;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Created by lesly.lai on 2018/7/25.
 */
public class CuxGroupPartitioner extends Partitioner {

    private int partitions;

    /**
     * Maps each group key to a partition index,
     * so that different keys go to different partitions.
     */
    private Map<Object, Integer> hashCodePartitionIndexMap = new ConcurrentHashMap<>();

    public CuxGroupPartitioner(List<Object> groupList) {
        int size = groupList.size();
        this.partitions = size;
        initMap(partitions, groupList);
    }

    private void initMap(int size, List<Object> groupList) {
        Assert.assertTrue(CollectionUtils.isNotEmpty(groupList));
        for (int i=0; i<size; i++) {
            hashCodePartitionIndexMap.put(groupList.get(i), i);
        }
    }

    @Override
    public int numPartitions() {
        return partitions;
    }

    @Override
    public int getPartition(Object key) {
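        // Assumes every key passed in was present in groupList; an unseen key would make the
        // lookup return null and throw a NullPointerException when unboxing to int.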
        return hashCodePartitionIndexMap.get(key);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj instanceof CuxGroupPartitioner) {
            return ((CuxGroupPartitioner) obj).partitions == partitions;
        }
        return false;
    }
}

Utility class for inspecting the partition distribution:
(1) Scala:

import org.apache.spark.sql.{Dataset, Row}

/**
 * Created by lesly.lai on 2017/12/25.
 */
class SparkRddTaskInfo {
  def getTask(dataSet: Dataset[Row]) {
    val size = dataSet.rdd.partitions.length
    println(s"==> partition size: $size " )
    import scala.collection.Iterator
    val showElements = (it: Iterator[Row]) => {
      val ns = it.toSeq
      import org.apache.spark.TaskContext
      val pid = TaskContext.get.partitionId
      println(s"[partition: $pid][size: ${ns.size}] ${ns.mkString(" ")}")
    }
    dataSet.foreachPartition(showElements)
  }
}

(2) Java:

import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class SparkRddTaskInfo {
    public static void getTask(Dataset<Row> dataSet) {
        int size = dataSet.rdd().partitions().length;
        System.out.println("==> partition size:" + size);

        JavaRDD<Row> dataRDD = dataSet.toJavaRDD();
        dataRDD.foreachPartition((VoidFunction<Iterator<Row>>) rowIterator -> {
            List<String> mappedRows = new ArrayList<String>();
            while (rowIterator.hasNext()) {
                Row next = rowIterator.next();
                String id = next.getAs("id");
                String partitionKey = next.getAs("partition_key");
                String name = next.getAs("name");
                mappedRows.add(id + "/" + partitionKey+ "/" + name);
            }
            int pid = TaskContext.get().partitionId();
            System.out.println("[partition: " + pid + "][size: " + mappedRows.size() + "]" + mappedRows);
        });
    }
}

How it is called:

import com.vip.spark.db.ConnectionInfos;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.util.List;
import java.util.stream.Collectors;

/**
 * Created by lesly.lai on 2018/7/23.
 */
public class SparkSimpleTestPartition {
	public static void main(String[] args) throws InterruptedException {
	
		SparkSession sparkSession = SparkSession.builder().appName("Java Spark SQL basic example").getOrCreate();
		// Original dataset
		Dataset<Row> originSet = sparkSession.read().jdbc(ConnectionInfos.TEST_MYSQL_CONNECTION_URL, "people", ConnectionInfos.getTestUserAndPasswordProperties());
		originSet
			.selectExpr("split(rowKey, '_')[0] as id",
				"concat(split(rowKey, '_')[0],'_',split(rowKey, '_')[1]) as partition_key",
				"split(rowKey, '_')[1] as name")
			.createOrReplaceTempView("people");
		// Collect the distinct partition keys; they define the custom partitions
		Dataset<Row> groupSet = sparkSession.sql(" select partition_key from people group by partition_key");
		List<Object> groupList = groupSet.javaRDD().collect().stream().map(row -> row.getAs("partition_key")).collect(Collectors.toList());
		// Create a pair RDD: only pair RDDs currently support a custom partitioner, so convert first
		JavaPairRDD<String, Row> pairRDD = originSet.javaRDD().mapToPair(row -> {
			return new Tuple2<>(row.<String>getAs("partition_key"), row);
		});
		// Apply the custom partitioner, then drop the key again
		JavaRDD<Row> javaRdd = pairRDD.partitionBy(new CuxGroupPartitioner(groupList)).map(new Function<Tuple2<String, Row>, Row>(){
			@Override
			public Row call(Tuple2<String, Row> v1) throws Exception {
				return v1._2;
			}
		});
		Dataset<Row> result = sparkSession.createDataFrame(javaRdd, originSet.schema());
		// Print the partition distribution
		SparkRddTaskInfo.getTask(result);
	}
}

4. Repartitioning a Dataset and getting the number of partitions

        System.out.println("1-->"+rowDataset.rdd().partitions().length);
        System.out.println("1-->"+rowDataset.rdd().getNumPartitions());
        Dataset<Row> coalesced = rowDataset.coalesce(2);
        System.out.println("2-->"+coalesced.rdd().partitions().length);
        System.out.println("2-->"+coalesced.rdd().getNumPartitions());

Output:

1-->29
1-->29
2-->2
2-->2

Note: when repartition() is used as in the snippet below, both prints show the same number of partitions:

print(rdd.getNumPartitions())
rdd.repartition(100)
print(rdd.getNumPartitions())

There are two reasons for this:
  First, repartition() is lazily evaluated, so the shuffle only runs once an action is executed.
  Second, and more importantly, repartition() returns a new RDD whose partitioning is set to the new partition count. You must use the returned RDD; otherwise you are still working with the old partitioning.
  The fix: rdd2 = rdd.repartition(100)
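
The same applies to the Dataset API in Java. A minimal sketch of the correct pattern (assuming a Dataset<Row> named rowDataset, as in the snippet above):

// Wrong: the Dataset returned by repartition() is discarded, so nothing changes
rowDataset.repartition(100);
System.out.println(rowDataset.rdd().getNumPartitions()); // still the old partition count

// Right: keep the returned Dataset and use it from here on
Dataset<Row> repartitioned = rowDataset.repartition(100);
System.out.println(repartitioned.rdd().getNumPartitions()); // 100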

Reference: repartition() is not affecting RDD partition size

5. Deduplication with dropDuplicates

  Purpose: deduplicate the rows of a DataFrame; if there are multiple duplicate rows, keep the first one (a Java version follows the PySpark example below).

# dropDuplicates with no arguments deduplicates on all columns
df.dropDuplicates().show()
# The API can also deduplicate on specific columns: passing age and job below means
# rows with the same age and job are treated as duplicates
df.dropDuplicates(['age','job']).show()
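
Since the rest of this post uses the Java API, here is a minimal Java sketch of the same two calls (assuming a Dataset<Row> named df with age and job columns):

// Deduplicate on all columns
df.dropDuplicates().show();
// Deduplicate on a subset of columns: rows with the same age and job count as duplicates
df.dropDuplicates(new String[]{"age", "job"}).show();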

Source: Big Data Development | How to deduplicate values in SparkSQL?

6. Converting a Dataset to a List

Tuple4<String, String, String, String> mySQLInfo = getMySQLInfo(configFile);
Properties prop = new Properties();
prop.setProperty("user", mySQLInfo._2());
prop.setProperty("password", mySQLInfo._3());
prop.setProperty("driver", mySQLInfo._4());
Dataset<Row> df = spark.read().jdbc(mySQLInfo._1(), tableName, prop);
List<String> collectAsList = df
    .selectExpr(typeId).dropDuplicates()
    .map((MapFunction<Row, String>) row -> row.mkString(","), Encoders.STRING()).collectAsList();

7. User-defined function (UDF)

// Dataset UDF: round a time string up to the half hour
spark.udf().register("timeCeil", (String field) -> {
    String[] timeSplit = field.split(":");
    // zero-pad single-digit numbers
    DecimalFormat g1 = new DecimalFormat("00");
    String hour = timeSplit[0];
    String standard;
    // round up to the next half-hour mark
    if (Integer.parseInt(timeSplit[1]) > 30) {
        hour = g1.format(Integer.parseInt(hour) + 1);
        standard = "00";
    } else {
        standard = "30";
    }
    return hour + ":" + standard + ":00";
}, DataTypes.StringType);

Dataset<Row> rowDataset = sourceData.selectExpr("Time", "timeCeil(Time) as HalfHour");

Result:
+----------+--------+--------+
|Date      |Time    |HalfHour|
+----------+--------+--------+
|2023-09-13|00:30:46|00:30:00|
|2023-09-13|00:30:51|00:30:00|
|2023-09-13|00:30:56|00:30:00|
|2023-09-13|00:31:01|01:00:00|
|2023-09-13|00:31:06|01:00:00|
|2023-09-13|00:31:11|01:00:00|

8. Replace functions

Dataset<Row> rowDataset = sourceData.selectExpr("replace(split(rowKey, '_')[0], '我爱你', '点赞加个关注呗') as studentId");
// Equivalent here, because the search string contains no regex metacharacters
Dataset<Row> rowDataset = sourceData.selectExpr("regexp_replace(split(rowKey, '_')[0], '我爱你', '点赞加个关注呗') as studentId");

Reference: Common functions in SparkSQL

9. Using na.fill

  DF.na.fill("NULL") 是使用 Spark DataFrame API 中的 na 方法来填充数据中的缺失值。具体地,该代码将 DataFrame 中的所有缺失值(即 null 值)都填充为字符串 NULL

  对两个数据表如A,B取JOIN操作的时候,其结果往往会出现NULL值的出现。这种情况是非常不利于后续的分析与计算的,特别是当涉及到对这个数值列进行各种聚合函数计算的时候。

  Spark 为此提供了一个高级操作,就是:na.fill 的函数。其处理过程就是先构建一个 MAP,如下:val map = Map("列名1“ -> 指定数字, "列名2“ -> 指定数字, .....),然后执行 dataframe.na.fill(map),即可实现对 NULL 值的填充。
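
A minimal Java sketch of the same map-based pattern (the column names col1 and col2 and their fill values are made up for illustration; sourceData is a Dataset<Row>):

Map<String, Object> fillValues = new HashMap<>(); // java.util.Map / java.util.HashMap
fillValues.put("col1", 0.0);       // numeric column: replace null with 0.0
fillValues.put("col2", "unknown"); // string column: replace null with "unknown"
Dataset<Row> filled = sourceData.na().fill(fillValues);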

References:
Usage of na.fill on a DataFrame
What does DF.na.fill("NULL") mean in Scala Spark, and why can it produce more rows than the original DF?

  In my own code at work I use it like this:

        Dataset<Row> rowDataset = sourceData
                .selectExpr("split(rowKey, '_')[0] as studentId",
                        "Date",
                        "Time",
                        "get_json_object(Heheda,'$.点个赞关注一下呗') as Hehe")
                .na().fill("0.0");

10. Using if

Dataset<Row> rowDataset4 = rowDataset3.selectExpr("studentId", "Date", "if(HalfHour='23:30:00',ts+1799,ts+1800) as tsMinusHalf")
        .orderBy("studentId");

Dataset<Row> rowDataset5 = rowDataset3.join(rowDataset4, rowDataset4.col("studentId").equalTo(rowDataset3.col("studentId"))
        .and(rowDataset4.col("Date").equalTo(rowDataset3.col("Date"))), "left")
        .selectExpr("studentId", "Date", 
                "if(min_TotalPeople is null,first_TotalPeople,min_TotalPeople) as min_TotalPeople")
        .orderBy(functions.col("studentId").asc(),
                functions.col("Date").asc());
