pyspark dataframe将一行分成多行并标记序号(index)


gid score
a1 90 80 79 80
a2 79 89 45 60
a3 57 56 89 75
from pyspark.sql.functions import udf, col
from pyspark.sql.types import MapType, IntegerType, StringType

def udf_array_to_map(array):
    """Convert a sequence into a dict keyed by each element's position.

    :param array: sequence of values, or None
    :return: dict mapping the 0-based index of each element to the element,
             or None when ``array`` is None
    """
    if array is None:
        # Pass a missing value through unchanged instead of failing on it.
        return None
    return dict(enumerate(array))

# col(): returns a column based on the given column name
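# MapType: represents a set of key-value pairs. keyType gives the data type of the keys
#          and valueType gives the data type of the values.
#          The last argument (valueContainsNull) indicates whether the map's values may be null.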
def generate_idx_for_df(df, id_name, col_name, col_schema):
    """
    Explode rows that hold an array column into one new row per array element,
    with 'INTEGER_IDX' indicating the element's index in the original array.

    :param df: dataframe with array columns
    :param id_name: the id field of df
    :param col_name: the col of df to explode
    :param col_schema: the schema of each element in col_name array
    :return: new df with exploded rows; columns are ``id_name``,
             'INTEGER_IDX' (the map key) and 'col' (the map value).
    """
    # Imported locally so the function works even when `explode` has not
    # (yet) been imported at module level.
    from pyspark.sql.functions import explode

    # The UDF's return type is declared as a MapType because explode() only
    # accepts array or map columns.
    idx_udf = udf(udf_array_to_map, MapType(IntegerType(), col_schema, True))
    return (
        df.withColumn('idx_columns', idx_udf(col(col_name)))
        .select(id_name, explode('idx_columns').alias('INTEGER_IDX', 'col'))
    )



gid s idx_columns
a1 [90, 80, 79, 80] {0=90, 1=80, 2=79...
a2 [79, 89, 45, 60] {0=79, 1=89, 2=45...
a3 [57, 56, 89, 75] {0=57, 1=56, 2=89...

org.apache.spark.sql.AnalysisException: cannot resolve 'explode(idx_columns)' due to data type mismatch: input to function explode should be array or map type, not StringType;


gid s idx_columns
a1 [90, 80, 79, 80] Map(0 -> 90, 1 ->...
a2 [79, 89, 45, 60] Map(0 -> 79, 1 ->...
a3 [57, 56, 89, 75] Map(0 -> 57, 1 ->...
from pyspark.sql.functions import split, explode
# Turn the space-separated "score" string into an array column named "s",
# keeping only the id column and the new array column.
df_split = df.select('gid', split(df['score'], " ").alias("s"))
# After split() every array element is a string.
col_schema = StringType()
# One output row per (gid, index, value) triple.
df_index = generate_idx_for_df(
    df=df_split,
    id_name='gid',
    col_name='s',
    col_schema=col_schema,
)

最后分割完成后的结果如下所示：

a1 0 90
a1 1 80
a1 2 79
a1 3 80
a2 0 79
a2 1 89
a2 2 45
a2 3 60
a3 0 57
a3 1 56
a3 2 89
a3 3 75


