spark递归行转列,list转dataset

// Recursive row-to-column pivot: read a category-hierarchy CSV, resolve each
// node's full ancestor chain, pad every chain to a uniform width, and rebuild
// the result as a Dataset<Row>.
// NOTE(review): generic type arguments (Dataset<Row>, List<List<String>>, ...)
// were restored here — the original listing had them stripped by HTML escaping.
SparkSession spark = SparkSession
        .builder()
        .master("local")
        .appName("JavaFPGrowthExample")
        .getOrCreate();
Dataset<Row> csv = spark.read().option("header", "true")
        .csv("E:\\idea\\taskmanage\\task-spark\\data\\lineToColumn\\ods_crb_ems_t_exp_category_d1.csv");
// Column names contain dots, so they must be backtick-quoted for select().
Dataset<Row> select = csv.select("`ods_crb_ems_t_exp_category_d1.exp_category_id`", "`ods_crb_ems_t_exp_category_d1.exp_category_name`", "`ods_crb_ems_t_exp_category_d1.parent_id`","`ods_crb_ems_t_exp_category_d1.parent_name`")
        .withColumnRenamed("ods_crb_ems_t_exp_category_d1.exp_category_id", "id")
        .withColumnRenamed("ods_crb_ems_t_exp_category_d1.exp_category_name", "name")
        .withColumnRenamed("ods_crb_ems_t_exp_category_d1.parent_id", "pid")
        .withColumnRenamed("ods_crb_ems_t_exp_category_d1.parent_name", "pName");
select.show(20);
List<Row> rows = select.collectAsList();
List<List<String>> all = new ArrayList<>();
for (Row row : rows) {
    String id = row.getString(row.fieldIndex(idColumn));
    String name = row.getString(row.fieldIndex(nameColumn));
    // Chain is built leaf-first as (name, id, parentName, parentId, ...).
    List<String> pids = new ArrayList<>();
    pids.add(name);
    pids.add(id);
    getPids(rows, id, pids);
    Collections.reverse(pids);  // root first
    // Drop the topmost (pid, pName) pair — the root's parent is not a real node.
    pids.remove(0);
    pids.remove(0);
    all.add(pids);
}

List<List<String>> collect = all.stream()
        .sorted(Comparator.comparing(v -> v.get(0)))
        .collect(Collectors.toList());
// Find the widest chain so every row can be padded to the same column count.
int max = 0;
for (List<String> strings : collect) {
    max = Math.max(max, strings.size());
}
List<Row> list = new ArrayList<>(collect.size());
// BUG FIX: the original wrapped this in `for (int i = 0; i < max; i++)`, which
// appended each row `max` times — after the first pass every row is already
// padded, so every later pass re-added identical Rows. A single pass suffices.
for (List<String> strings : collect) {
    // Pad short chains by repeating the leaf (code, name) pair.
    int pairsToAdd = (max - strings.size()) / 2;
    for (int j = 0; j < pairsToAdd; j++) {
        String code = strings.get(strings.size() - 2);
        String leafName = strings.get(strings.size() - 1);
        strings.add(code);
        strings.add(leafName);
    }
    list.add(RowFactory.create(strings.toArray()));
}

JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
JavaRDD<Row> parallelize = sc.parallelize(list);
// GENERALIZED: the original hard-coded 6 StructFields (c0..c5); the schema now
// matches the computed row width, so createDataFrame works for any hierarchy depth.
List<StructField> fields = new ArrayList<>(max);
for (int i = 0; i < max; i++) {
    fields.add(DataTypes.createStructField("c" + i, DataTypes.StringType, true));
}
StructType schema = DataTypes.createStructType(fields);
parallelize.foreach(v -> System.out.println(v));  // debug dump; runs on executors
Dataset<Row> dataFrame = spark.createDataFrame(parallelize, schema);
dataFrame.show();
/**
 * Walks up the parent chain starting from {@code id}, appending each ancestor's
 * (parentName, parentId) pair to {@code pids} in leaf-to-root order.
 *
 * <p>BUG FIX: the original recursed once per ancestor with no termination guard —
 * a self-referencing or cyclic parent chain caused a StackOverflowError, and a
 * {@code null} parent_id caused an NPE on the recursive {@code id.equals(str)}.
 * Rewritten as an iterative walk with a visited-set guard; acyclic data produces
 * exactly the same sequence of appends as before.
 *
 * @param rows all hierarchy rows, collected to the driver
 * @param id   the id whose ancestors should be resolved
 * @param pids output list; ancestor pairs are appended to it
 */
public static void getPids(List<Row> rows, String id, List<String> pids) {
    String current = id;
    Set<String> seen = new HashSet<>();  // cycle guard
    while (current != null && seen.add(current)) {
        String next = null;
        for (Row row : rows) {
            if (current.equals(row.getString(row.fieldIndex(idColumn)))) {
                String parentId = row.getString(row.fieldIndex(pidColumn));
                pids.add(row.getString(row.fieldIndex(pName)));
                pids.add(parentId);
                next = parentId;
                break;  // ids are assumed unique — first match wins, as before
            }
        }
        current = next;  // null (no match / no parent) ends the walk
    }
}

/**
 * Dataset-based variant of {@code getPids}: resolves the chain of ancestor ids
 * for {@code id} and appends each parent id to {@code pids}.
 *
 * <p>BUG FIX: the original used {@code dataset.foreach(...)}, whose closure runs
 * on executors — mutations of the driver-local {@code pids} list were silently
 * lost. It also matched {@code id} against {@code pidColumn} and then re-read
 * {@code pidColumn}, so any hit recursed forever with the same id. The rows are
 * now collected to the driver and walked iteratively, matching on {@code idColumn}
 * consistently with {@code getPids}.
 *
 * @param dataset hierarchy rows (collected to the driver; keep datasets small)
 * @param id      the id whose ancestor ids should be resolved
 * @param pids    output list; ancestor ids are appended to it
 */
public static void getPid(Dataset<Row> dataset, String id, List<String> pids) {
    List<Row> rows = dataset.collectAsList();
    String current = id;
    Set<String> seen = new HashSet<>();  // cycle guard
    while (current != null && seen.add(current)) {
        String next = null;
        for (Row row : rows) {
            if (current.equals(row.getString(row.fieldIndex(idColumn)))) {
                next = row.getString(row.fieldIndex(pidColumn));
                pids.add(next);
                break;
            }
        }
        current = next;
    }
}

你可能感兴趣的:(spark)