PySpark: encoding multiple categorical feature columns with Pipeline(stages=[StringIndexer, ...])

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession
import safe_config  # local config module providing TEXT_CATEGORICAL_COLS

spark_app_name = 'lgb_hive_data'
spark = SparkSession.builder \
    .config('spark.executor.memory', '13g') \
    .config('spark.executor.cores', '3') \
    .config('spark.driver.memory', '20g') \
    .config('spark.executor.instances', '70') \
    .config('spark.sql.execution.arrow.enabled', 'true') \
    .config('spark.driver.maxResultSize', '20g') \
    .config('spark.default.parallelism', '9000') \
    .config('spark.sql.sources.default', 'orc') \
    .config('spark.sql.sources.partitionOverwriteMode', 'dynamic') \
    .config('spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation', 'true') \
    .appName(spark_app_name) \
    .enableHiveSupport().getOrCreate()

df = spark.sql("""
select * from aiplatform.travel_risk_index_safe_rule where pt >= '20200801' 
""")
# one StringIndexer per categorical text column, chained in a single Pipeline
pipeline = Pipeline(stages=[
    StringIndexer(inputCol=c, outputCol='{}_new_col'.format(c), handleInvalid='keep')
    for c in safe_config.TEXT_CATEGORICAL_COLS
])
model = pipeline.fit(df)
indexed = model.transform(df)
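
As an aside, handleInvalid="keep" means that labels the indexer never saw during fit() are mapped to one extra bucket at index len(labels), instead of raising an error at transform time. A minimal standalone sketch to illustrate (the toy column name and values here are made up, not from the real table):

from pyspark.ml.feature import StringIndexer

toy_train = spark.createDataFrame([('a',), ('b',), ('a',)], ['city'])
toy_test = spark.createDataFrame([('a',), ('c',)], ['city'])   # 'c' never seen at fit time

toy_model = StringIndexer(inputCol='city', outputCol='city_new_col',
                          handleInvalid='keep').fit(toy_train)
toy_model.transform(toy_test).show()
# labels are ordered by descending frequency, so 'a' -> 0.0,
# and the unseen 'c' falls into the extra bucket at index 2.0 (= len(labels))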

# each "<col>_new_col" column's metadata carries the fitted labels, ordered by index
index_dict = {c.name: c.metadata["ml_attr"]["vals"]
              for c in indexed.schema.fields if c.name.endswith("_new_col")}
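
The same mapping can also be read straight off the fitted pipeline: each stage is a StringIndexerModel whose .labels list is ordered by index (this is the approach in the Stack Overflow answer referenced below). A sketch, relying on the fact that PipelineModel.stages preserves the order of TEXT_CATEGORICAL_COLS:

index_dict_from_stages = {
    '{}_new_col'.format(c): stage.labels   # labels[i] is the string encoded as i
    for c, stage in zip(safe_config.TEXT_CATEGORICAL_COLS, model.stages)
}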

# label encoder dict: helpers to make it JSON-serializable
import json
import numpy as np

def key_to_json(data):
    if data is None or isinstance(data, (bool, int, str, float)):
        return data
    if isinstance(data, (tuple, frozenset)):
        return str(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        return float(data)
    raise TypeError('unserializable key: {!r}'.format(data))
def to_json(data):
    if data is None or isinstance(data, (bool, int, tuple, range, str, list)):
        return data
    if isinstance(data, (set, frozenset)):
        return sorted(data)
    if isinstance(data, np.floating):
        return float(data)
    if isinstance(data, dict):
        return {key_to_json(key): to_json(data[key]) for key in data}
    raise TypeError('unserializable value: {!r}'.format(data))
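
A quick sanity check of the helpers (toy values only): numpy keys are coerced to plain ints, and json.dumps then stringifies them as usual.

demo = {np.int64(3): {'beijing': 0, 'shanghai': 1}}
print(json.dumps(to_json(demo)))   # {"3": {"beijing": 0, "shanghai": 1}}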
    
# invert each label list into a {label: index} mapping per column
text_index_dict = {}
for col_name, col_values in index_dict.items():
    text_index_dict[col_name] = {label: idx for idx, label in enumerate(col_values)}
with open('./index.json', 'w') as fp:
    json.dump(to_json(text_index_dict), fp)
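
Downstream (e.g. when preparing training data for LightGBM, per the app name above), the dump can be reloaded and applied without Spark. A pandas-based sketch; encode_column and pdf are hypothetical names, and the fallback for unseen labels mirrors handleInvalid='keep' by assigning len(mapping):

import json
import pandas as pd

with open('./index.json') as fp:
    text_index_dict = json.load(fp)

def encode_column(series, mapping):
    # unseen labels fall into the extra bucket len(mapping), mirroring handleInvalid='keep'
    return series.map(mapping).fillna(len(mapping)).astype(int)

# hypothetical usage, assuming a pandas DataFrame pdf with the raw text column:
# pdf['city_new_col'] = encode_column(pdf['city'], text_index_dict['city_new_col'])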

References

https://stackoverflow.com/questions/45885044/getting-labels-from-stringindexer-stages-within-pipeline-in-spark-pyspark

Common PySpark feature engineering methods, part 1 (pyspark特征工程常用方法(一))

https://blog.csdn.net/Katherine_hsr/article/details/81004708
