PySpark（SparkSession）数据处理与数据建模

导入相关包

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, when, count, countDistinct
from pyspark.sql.types import IntegerType,StringType
from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Create (or reuse) a Hive-enabled SparkSession.
# The builder chain is wrapped in parentheses: a multi-line method chain without
# parentheses or trailing backslashes is a SyntaxError in Python.
spark = (
    SparkSession.builder
    # Placeholder option copied from the Spark docs example; has no effect here.
    .config("spark.some.config.option", "some-value")
    .config('spark.debug.maxToStringFields', '50')
    .appName("Python Spark SQL Hive integration example")
    .enableHiveSupport()
    .getOrCreate()
)
sc = spark.sparkContext

1.读入数据

读入数据库中的数据X

# Load the lead table (features X), excluding test leads whose user name
# contains "测试" ("test").
data = spark.sql('''select * from db_so_default_tenant.entity_clueinfo
                    where custom_username not like '%测试%' 
                 ''')
# NOTE: the date-window filter below is disabled because it raised an error.
# A likely cause is the MySQL-style '%Y-%m-%d' patterns: Spark SQL's
# FROM_UNIXTIME / date_format expect Java datetime patterns such as
# 'yyyy-MM-dd' — TODO verify and re-enable.
#                           and FROM_UNIXTIME(custom_create_time,'%Y-%m-%d') between date_format(date_sub(current_date,365), '%Y-%m-01') 
#                           and date_format(date_sub(current_date, 15), '%Y-%m-%d')

# Inspect the column types.
data.printSchema()

# Keep only the columns used downstream. Every name must appear exactly once:
# listing a column twice (as 'custom_detail' was) creates two same-named
# columns, which makes later references to that name ambiguous.
keep_var_lst = ['custom_clue_id', 'custom_create_time', 'custom_post_time', 'custom_username', 'custom_sex', 'custom_mobile',
                'custom_mobile_area', 'custom_approach_id', 'custom_channel_id', 'custom_product_id', 'custom_pattern_id', 'custom_media_id',
                'custom_ctype_id', 'custom_activity_id', 'custom_detail', 'custom_province_id', 'custom_city_id', 'custom_district_id',
                'custom_utm_source', 'custom_utm_content', 'custom_utm_medium', 'custom_utm_campaign', 'custom_resource',
                'custom_dealer_id', 'custom_area_id', 'custom_two_area_id',
                ]

data = data.select(keep_var_lst)

读入数据库中的Y

# --- Target table (Y) ---------------------------------------------------------
# Read the lead-verification feedback table and print its schema.
lead_feedback = spark.sql("select * from db_so_default_tenant.entity_clueinfosync")
lead_feedback.printSchema()

# Keep just the join key plus the verification status and sync time.
keep_var_lst2 = [
    'custom_clue_id',
    'custom_verify_status',
    'custom_sync_time',
]
lead_feedback = lead_feedback.select(*keep_var_lst2)

# Shape recorded by the author as (rows, columns): (1577626, 3)
# print((lead_feedback.count(), len(lead_feedback.columns)))

join表,得到包含X和Y的基础表

# Attach the feedback columns to every lead. A left join keeps leads that
# have no feedback row; their feedback columns are null.
df = data.join(lead_feedback, 'custom_clue_id', 'left')

# Shapes recorded by the author as (rows, columns): (1466832, 29), later (1560986, 29)
# print((df.count(), len(df.columns)))

2. 数据整合

定义Y值

# Binary target: 1 when custom_verify_status equals 2, otherwise 0.
# A null status (lead without feedback) makes the condition null, which
# falls through to otherwise(0).
df = df.withColumn('label', when(df.custom_verify_status == 2, 1).otherwise(0))

日期数据的处理

# Convert the Unix-timestamp columns into formatted date/time strings.

# Register the joined frame as temp view "temp" so it can be queried with SQL.
df.createOrReplaceTempView("temp")
# from_unixtime renders custom_create_time / custom_post_time (presumably epoch
# seconds — TODO confirm) as month ('yyyy-MM'), day ('yyyy-MM-dd') and full
# 'yyyy-MM-dd HH:mm:ss' strings; the latter two feed the next feature step.
newDF = spark.sql("""select *, 
                  from_unixtime(custom_create_time, 'yyyy-MM')as dim_month_id,
                  from_unixtime(custom_create_time, 'yyyy-MM-dd')as dim_day_id,
                  from_unixtime(custom_create_time, 'yyyy-MM-dd HH:mm:ss')as create_time_new,
                  from_unixtime(custom_post_time, 'yyyy-MM-dd HH:mm:ss')as post_time_new
                  from temp """)
# Extract calendar and time-of-day features from the create/post timestamps.

# Re-register under the same view name so the SQL sees create_time_new / post_time_new.
newDF.createOrReplaceTempView("temp")
# For both create_time_new and post_time_new this adds:
#   *_monthofyear, *_weekofmonth (days 1-7 -> 1, 8-14 -> 2, ...), *_dayofweek,
#   *_weekofyear, *_hourofday, *_hourofday2 (two-hour bucket: floor(hour / 2)),
#   and *_hour_flag, a hand-picked hour bucket. The buckets differ between
#   create and post times.
# The letter prefixes on the bucket labels ('a.', 'b.', ...) presumably keep
# them in time order when sorted as strings (note 'i' is skipped in the
# create labels) — TODO confirm intent.
newDF = spark.sql("""select *, 
                  month(create_time_new) as create_monthofyear,
                  FLOOR((day(create_time_new)-1)/7)+1 as create_weekofmonth,
                  dayofweek(create_time_new) as create_dayofweek,
                  weekofyear(create_time_new) as create_weekofyear,
                  hour(create_time_new) as create_hourofday,
                  floor(hour(create_time_new)/2) as create_hourofday2,
                  case when hour(create_time_new) between 8  and 11 then 'a.8-11'
                       when hour(create_time_new) =12               then 'b.12'
                       when hour(create_time_new) between 13 and 17 then 'c.13-17'
                       when hour(create_time_new) between 18 and 19 then 'd.18-19'
                       when hour(create_time_new) between 20 and 23 then 'e.20-23'
                       when hour(create_time_new) =0                then 'f.0'
                       when hour(create_time_new) between 1  and 2  then 'g.1-2'
                       when hour(create_time_new) =3                then 'h.3'
                       when hour(create_time_new) between 4 and 5   then 'j.4-5'
                       when hour(create_time_new) between 6 and 7   then 'k.6-7'
                       end as create_hour_flag,
                       
                  month(post_time_new) as post_monthofyear,
                  FLOOR((day(post_time_new)-1)/7)+1 as post_weekofmonth,
                  dayofweek(post_time_new) as post_dayofweek,
                  weekofyear(post_time_new) as post_weekofyear,
                  hour(post_time_new) as post_hourofday,
                  floor(hour(post_time_new)/2) as post_hourofday2,
                  case when hour(post_time_new) between 9  and 11 then 'a.9-11'
                       when hour(post_time_new) =12 then 'b.12'
                       when hour(post_time_new) between 13 and 19 then 'c.13-19'
                       when hour(post_time_new) =20 then 'd.20'
                       when hour(post_time_new) between 21 and 23 then 'e.21-23'
                       when hour(post_time_new) between 0  and 2  then 'f.0-2'
                       when hour(post_time_new) =3  then 'g.3'
                       when hour(post_time_new) between 4  and 8  then 'h.4-8'
                    end as post_hour_flag

                  from temp""")

1. 创建简单 flag：值为 null、0、空字符串或 '0' 时返回 0，否则返回 1

# Continue feature engineering on df2, starting from newDF (with the date features).
df2 = newDF
# Simple presence flag: 0 for a missing/empty/zero value, otherwise 1.
def func_var_flag(var):
    """Return 0 if *var* is None, 0, '' or '0'; otherwise return 1.

    Wrapped below as an IntegerType Spark UDF to flag whether a raw column
    carries a meaningful value.
    """
    # `is None` is the correct identity test; the remaining "empty" markers
    # are compared by equality, so 0.0 also counts as empty.
    if var is None or var in (0, '', '0'):
        return 0
    return 1
    
# Wrap func_var_flag as an integer Spark UDF and append one '<column>_flag'
# presence indicator for each raw column listed below.
func_var_flag_udf = udf(func_var_flag, IntegerType())
unknown_flag = [
    'custom_username', 'custom_mobile_area', 'custom_approach_id', 'custom_channel_id',
    'custom_product_id', 'custom_pattern_id', 'custom_media_id', 'custom_ctype_id',
    'custom_activity_id', 'custom_utm_source', 'custom_utm_content', 'custom_utm_medium',
    'custom_utm_campaign', 'custom_province_id', 'custom_city_id', 'custom_district_id',
    'custom_dealer_id', 'custom_area_id', 'custom_two_area_id', 'custom_resource',
    'custom_detail',
]
# Append all flag columns in one projection, in the same order as the list.
df2 = df2.select(
    '*',
    *[func_var_flag_udf(df2[name]).alias(name + '_flag') for name in unknown_flag]
)

2. 创建简单 flag：值为 null 或空字符串时返回 'Unk'，否则返回原值

对于数值型的数据未做处理

# Group flag: map a missing value (None or '') to 'Unk', keep everything else.
def func_var_grp_flag(var):
    """Return 'Unk' when *var* is None or an empty string, otherwise *var* unchanged.

    Wrapped below as a StringType Spark UDF. Non-string inputs (e.g. numbers)
    are returned as-is; they are not grouped here — how Spark renders them
    under the StringType return is not handled by this function (TODO confirm
    whether numeric columns need their own treatment).
    """
    if var is None or var == '':
        return 'Unk'
    return var
    
# StringType UDF around func_var_grp_flag; appends '<column>_grp' for each
# column below, in list order.
func_var_grp_udf = udf(func_var_grp_flag, StringType())
unknown_grp_flag = ['custom_sex', 'custom_utm_medium']
df2 = df2.select(
    '*',
    *[func_var_grp_udf(df2[name]).alias(name + '_grp') for name in unknown_grp_flag]
)

# Current column count (notebook display).
len(df2.columns)

3.字符串格式的case when,使用sql

# Derive categorical features from raw string columns with SQL CASE WHEN.
# Register df2 as temp view "temp" so it can be queried with SQL.
df2.createOrReplaceTempView("temp")
# Features built here:
#   name_len / name_flag2   - user-name length and a coarse name category.
#                             'phone_num' = the user NAME looks like an 11-digit
#                             number starting with '1', so both checks now
#                             inspect custom_username (the first-character
#                             test previously read custom_mobile by mistake).
#   mobile_len / tel_flag / tel_head2 / tel_head3 / tel_head3_grp
#                           - phone-number features. All length checks use the
#                             trimmed number so the flags agree with each other,
#                             and first-digit tests compare against string
#                             literals ('0', '1') instead of relying on an
#                             implicit string-to-int cast.
#   mobile_area_grp, channel_grp, media_grp, comment_type, custom_province_grp,
#   area_grp                - hand-picked groupings of id / text columns.
df3 = spark.sql("""select *,
                    CHAR_LENGTH(trim(custom_username)) as name_len,
                    case when CHAR_LENGTH(custom_username) = 1 then 'len=1'
                         when custom_username in ('400用户','询价客户','客户','团购用户','微聊客户','网友','报价用户','匿名用户',
                                                  '汽车之家用户','车主','佚名',
                                                  '爱卡用户','询价用户','17汽车来电客户','团购客户','匿名','意向客户') then custom_username
                         when custom_username like '%先生%' or custom_username like '%女士%' then 'x Mr/Mrs'
                         when SUBSTR(trim(custom_username),1,1) = '1' and CHAR_LENGTH(trim(custom_username)) = 11 then 'phone_num'
                         when substr(custom_username,1,1) in ('0','1','2','3','4','5','6','7','8','9') then 'numbers'
                         when CHAR_LENGTH(custom_username) > 3 then 'len>3'
                         else 'Normal'
                    end as name_flag2,

                    CHAR_LENGTH(trim(custom_mobile)) as mobile_len,
                    case when SUBSTR(trim(custom_mobile),1,1) = '0' then 'fixed-line telephone'
                         when SUBSTR(trim(custom_mobile),1,1) = '1' and CHAR_LENGTH(trim(custom_mobile)) = 11 then 'mobile phone'
                         else 'No-valid'
                    end as tel_flag,
                    case when SUBSTR(trim(custom_mobile),1,1) = '1' and CHAR_LENGTH(trim(custom_mobile)) = 11 then SUBSTR(trim(custom_mobile),1,2)
                    end as tel_head2,
                    case when SUBSTR(trim(custom_mobile),1,1) = '1' and CHAR_LENGTH(trim(custom_mobile)) = 11 then SUBSTR(trim(custom_mobile),1,3)
                    end as tel_head3,
                    case when CHAR_LENGTH(trim(custom_mobile)) <> 11 then 'Not-Phone'
                         when SUBSTR(trim(custom_mobile),1,3) in ('186','138','139','135','136','137','159','158','150','151',
                                                                  '187','182','189','152','188','176','185','180','183','133',
                                                                  '181','177','131','130','132','156','134','153','155','173',
                                                                  '157','199','178','175','166','184','198','147','191','170','171'
                                                                  ) then 'valid'
                         else 'No-Valid'
                    end as tel_head3_grp,
                    case when custom_mobile_area is null or custom_mobile_area = '' then 'Unk'
                         when custom_mobile_area in ('海口市','大连市','昆明市','吉林市','江门市','西宁市','珠海市','呼和浩特市','张家口市')
                              then 'level1'
                         when custom_mobile_area in ('金华市','赣州市','湖州市','徐州市','盐城市') then 'level2'
                         when custom_mobile_area in ('沈阳市','成都市') then 'level3'
                         when custom_mobile_area in ('杭州市','南京市','宜春市','吉安市') then 'level4'
                         else 'Others'
                    end as mobile_area_grp,

                    case when custom_channel_id in ('73','72','10070','62','10063','61','10012','10061','65','60','10072','76',
                                                    '10062','10071','63','10073','36','77') then custom_channel_id
                         else 'Others'
                    end as channel_grp,
                    case when custom_media_id in ('4f15069347ea4') then 'level1'
                         when custom_media_id in ('4f15069348034') then 'level2'
                         when custom_media_id in ('5c7397fa8c5f3') then 'level3'
                         when custom_media_id in ('5aa8e618a1915','58107fdf18a64') then 'level4'
                         when custom_media_id in ('588176b5dc052','4f150a09d9a7d','541994c0e4126','54068f14cde9b',
                                                  '5a308c5df0537','54052681387a5','54068f14cde9h','5c6d2672f1f95',
                                                  '57d2a59bc8dbb','4f15053feac73','5c233d3561514','4f150693481c2',
                                                  '4f15069348647','4f150a09db456','4f150a09d608c') then 'level5'
                         when custom_media_id in ('0') then 'Unk'
                         else 'Others'
                    end as media_grp,

                    case when custom_detail is null or custom_detail = '' then NULL
                         when custom_detail like '%询价%'   then 'Inquire'
                         when custom_detail like '%经销商%' then 'Retail'
                         when custom_detail like '%试驾%'   then 'Trial run'
                         when custom_detail like '2.0L %' or custom_detail like '2.5L %' then 'car_type'
                         when custom_detail like '%通话%'   then 'comment6'
                         when custom_detail like '%失败%'   then 'comment2'
                         when custom_detail like '%成功%'   then 'comment1'
                         when custom_detail like '%无效%'   then 'comment3'
                         when custom_detail like '%黑名单%' then 'comment4'
                         when custom_detail like '%姓名%'   then 'comment5'
                    end as comment_type,

                    case when custom_province_id in ('150000','460000','630000','530000','620000','520000','650000','24') then 'level1'
                         when custom_province_id in ('440000','610000','31','220000','640000') then 'level2'
                         when custom_province_id in ('130000','430000','370000','25','410000','210000','340000') then 'level3'
                         when custom_province_id in ('420000','350000','230000') then 'level4'
                         when custom_province_id in ('320000','450000','510000','360000','140000','330000','2') then 'level5'
                    end as custom_province_grp,

                    case when custom_area_id in ('215','499') then 'South'
                         when custom_area_id in ('497')       then 'North'
                         when custom_area_id in ('500')       then 'East2'
                         when custom_area_id in ('20004')     then 'East1'
                         when custom_area_id in ('221','501') then 'North-East'
                         when custom_area_id in ('502')       then 'West'
                    end as area_grp

                  from temp """)
# Column count (the author recorded 84).
len(df3.columns)

删除一些不需要的列

# Drop raw columns that are either replaced by derived features above or are
# the source of the target (verify status / sync time).
drop_list1 = [
    'custom_create_time', 'custom_post_time', 'create_time_new', 'post_time_new',
    'custom_verify_status', 'custom_sync_time',
    'custom_username', 'custom_mobile', 'custom_mobile_area', 'custom_media_id',
    'custom_utm_source', 'custom_utm_content', 'custom_utm_medium', 'custom_utm_campaign', 'custom_detail',
]
df4 = df3.drop(*drop_list1)

len(df4.columns)

删除一些ID字段

# Drop identifier / date-key columns: only model inputs (plus the label)
# should remain in the final dataset.
drop_attrs = ["custom_clue_id", "dim_month_id", "dim_day_id"]
df4 = df4.drop(*drop_attrs)
len(df4.columns)

判断是否有唯一值的无关列,并进行删除

# Find and drop columns that hold a single distinct value (distinct() treats
# null as a value of its own); such columns carry no information.
# Slow: launches one Spark job per column.
one_value_flag = [name for name in df4.columns
                  if df4.select(name).distinct().count() == 1]
one_value_flag
df4 = df4.drop(*one_value_flag)
len(df4.columns)

数值转换为字符串格式

# Some id columns were read in as numbers; cast them to strings so that they
# are treated as categorical features further down.
df5 = df4
int_to_string_list = ['custom_approach_id', 'custom_channel_id', 'custom_product_id', 'custom_pattern_id', 'custom_ctype_id',
                      'custom_activity_id', 'custom_province_id', 'custom_city_id', 'custom_district_id',
                      'custom_dealer_id', 'custom_area_id', 'custom_two_area_id'
                      ]
# The constant-column step above drops columns depending on the data, so a
# listed id may no longer exist; skip it rather than fail on df5[col].
present_cols = set(df5.columns)
for col in int_to_string_list:
    if col in present_cols:
        df5 = df5.withColumn(col, df5[col].cast(StringType()))

# Numeric features: every non-string column except the target 'label'.
numeric_cols = [name for name, dtype in df5.dtypes if dtype != 'string' and name != 'label']
numeric_cols
# Categorical features: every string column ('label' is numeric, so it is excluded).
string_cols = [name for name, dtype in df5.dtypes if dtype == 'string']
string_cols

字符串填充缺失值

# Encoding failed when categorical columns contained nulls, so replace null
# and empty-string values in every string column with the placeholder 'EMPTY'.
# The fill value comes first and the columns go in `subset`. Calling
# na.fill(col, 'EMPTY') swaps them: it fills a non-existent column 'EMPTY'
# with the column name.
df5 = df5.na.fill('EMPTY', subset=string_cols)
df5 = df5.na.replace('', 'EMPTY', subset=string_cols)

判断每一个分类列的类别数是否达到 25（>= 25）

便于之后的管道处理：类别数不少于 25 的列只做 StringIndexer 转换，少于 25 的列再做 one-hot 编码

If a column has 25 or more categories, add it to the drop list, or convert it to a continuous variable if possible.

# Bucket the categorical columns by number of distinct values.
# Slow: one Spark job per column.
#   >= 32 levels -> string_more_than32 (dropped in the next step)
#   >= 25 levels -> string_more_than25 (StringIndexer only, no one-hot)
#   <  25 levels -> string_less_than25 (StringIndexer + one-hot)
# The 32 cut-off presumably matches maxBins=32 of the random forest trained
# below — TODO confirm.
string_more_than32 = []
string_more_than25 = []
string_less_than25 = []

for column in string_cols:
    # Count distinct values once per column and reuse it for both tests.
    n_levels = df5.select(column).distinct().count()
    if n_levels >= 32:
        string_more_than32.append(column)
    if n_levels >= 25:
        string_more_than25.append(column)
    else:
        string_less_than25.append(column)

string_more_than32

删除类别数不少于 32 的分类变量

# Drop the high-cardinality (>= 32 levels) categorical columns.
df5 = df5.drop(*string_more_than32)
len(df5.columns)
string_more_than25
# Columns with 25-31 levels (index only, no one-hot). A list comprehension
# keeps the original column order. A set difference would make the order
# depend on Python's string-hash seed, and this list sets the layout of the
# assembled feature vector.
string_25_than32 = [c for c in string_more_than25 if c not in string_more_than32]
string_25_than32
string_less_than25
# Recompute the categorical column list after the drop.
string_cols = [name for name, dtype in df5.dtypes if dtype == 'string']
string_cols

3、管道处理

# Feature pipeline:
#   1. StringIndexer for every categorical column ('keep' maps labels unseen
#      at fit time to an extra index instead of failing);
#   2. one-hot encoding for the low-cardinality (< 25 levels) columns;
#   3. VectorAssembler merging numeric columns, the raw indices of the
#      25-31-level columns and the one-hot vectors into one 'features' column.
indexers = [StringIndexer(inputCol=c, outputCol=c + 'Index', handleInvalid="keep")
            for c in string_cols]
encoders = [OneHotEncoderEstimator(inputCols=[c + 'Index'], outputCols=[c + 'Vec'])
            for c in string_less_than25]

assemblerInputs = (numeric_cols
                   + [c + 'Index' for c in string_25_than32]
                   + [c + 'Vec' for c in string_less_than25])
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol='features')

main_stages = indexers + encoders + [assembler]

# Fit and apply the pipeline (the slow step).
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=main_stages)
pipelineModel = pipeline.fit(df5)
df6 = pipelineModel.transform(df5)

4、建立模型

划分数据集

# Keep only the target and the assembled feature vector.
dfi = df6.select('label', 'features')

# 70/30 train/test split; the fixed seed makes the split reproducible.
train, test = dfi.randomSplit([0.7, 0.3], seed=100)

# Counting the full data is very slow. Counts recorded by the author:
#   Training Dataset Count: 1249630
#   Test Dataset Count: 311356
# print("Training Dataset Count: " + str(train.count()))
# print("Test Dataset Count: " + str(test.count()))

Random Forest Classifier

# Random forest classifier: 100 trees, maxBins=32.
rf = RandomForestClassifier(labelCol='label',
                            featuresCol='features',
                            numTrees=100,
                            maxBins=32)

# Train on the training split.
rfModel = rf.fit(train)

# Score the held-out test split.
predictions = rfModel.transform(test)

# Show a sample of scored rows. The chain is wrapped in parentheses: a
# multi-line method chain without them is a SyntaxError in Python.
# NOTE(review): ordering by the 'probability' vector appears to sort on its
# first component (P(label=0)), judging by the sample output — confirm if
# the intent was to show the most likely positives.
(predictions
    .select('label', 'rawPrediction', 'prediction', 'probability')
    .orderBy('probability', ascending=False)
    .show(n=10, truncate=30))

+-----+------------------------------+----------+------------------------------+
|label| rawPrediction|prediction| probability|
+-----+------------------------------+----------+------------------------------+
| 0|[79.15890827146472,20.84109...| 0.0|[0.7915890827146475,0.20841...|
| 0|[79.10923525773862,20.89076...| 0.0|[0.7910923525773864,0.20890...|
| 0|[78.98945518105177,21.01054...| 0.0|[0.7898945518105179,0.21010...|
| 0|[78.9282993850366,21.071700...| 0.0|[0.7892829938503662,0.21071...|
| 0|[78.91212774787148,21.08787...| 0.0|[0.7891212774787151,0.21087...|
| 0|[78.89054837885494,21.10945...| 0.0|[0.7889054837885496,0.21109...|
| 0|[78.89054837885494,21.10945...| 0.0|[0.7889054837885496,0.21109...|
| 0|[78.89054837885494,21.10945...| 0.0|[0.7889054837885496,0.21109...|
| 0|[78.89054837885494,21.10945...| 0.0|[0.7889054837885496,0.21109...|
| 0|[78.89054837885494,21.10945...| 0.0|[0.7889054837885496,0.21109...|
+-----+------------------------------+----------+------------------------------+

# Evaluate the random forest by area under the ROC curve on the test split.
# The two statements previously shared one line, which is a SyntaxError.
evaluator = BinaryClassificationEvaluator()
print("Test Area Under ROC: " + str(evaluator.evaluate(predictions, {evaluator.metricName: "areaUnderROC"})))
# Result recorded by the author: Test Area Under ROC: 0.6160155402990332

保存模型

# Persist the trained random forest, overwriting any earlier save at this path.
# The path is relative; uncomment the lines below to see the current working
# directory (how Spark resolves the path is not shown here — TODO confirm).
# import sys, os
# os.getcwd() 
rfModel.write().overwrite().save('Model test/rfModel') 

加载模型

# Reload the random forest saved above from the same path.
from pyspark.ml.classification import RandomForestClassificationModel 
model_1 = RandomForestClassificationModel.load('Model test/rfModel') 

Gradient-Boosted Tree Classifier

# Gradient-boosted tree classifier with 10 boosting iterations; label and
# feature columns use the estimator defaults.
gbt = GBTClassifier(maxIter=10)
gbtModel = gbt.fit(train)

# Score the held-out test split and preview the first 10 rows.
predictions = gbtModel.transform(test)
preview_cols = ['label', 'rawPrediction', 'prediction', 'probability']
predictions.select(preview_cols).show(n=10)

+-----+--------------------+----------+--------------------+
|label| rawPrediction|prediction| probability|
+-----+--------------------+----------+--------------------+
| 0|[-0.0582178194283...| 1.0|[0.47092393217850...|
| 0|[-0.0667980984304...| 1.0|[0.46665053764714...|
| 0|[-0.0560469563372...| 1.0|[0.47200582803120...|
| 0|[0.04211971652931...| 0.0|[0.52104741320470...|
| 0|[0.08544882017875...| 0.0|[0.54262072878469...|
| 0|[-0.0728647167488...| 1.0|[0.46363198136231...|
| 0|[-0.0142166646760...| 1.0|[0.49289214652005...|
| 0|[0.08754857661758...| 0.0|[0.54366279043135...|
| 0|[-0.0676538770780...| 1.0|[0.46622457631215...|
| 0|[-0.0713656699888...| 1.0|[0.46437762010753...|
+-----+--------------------+----------+--------------------+

# Evaluate the GBT model with the same metric as the random forest (ROC AUC).
evaluator = BinaryClassificationEvaluator()
auc = evaluator.evaluate(predictions, {evaluator.metricName: "areaUnderROC"})
print("Test Area Under ROC: " + str(auc))

# Persist the gradient-boosted tree model, overwriting any earlier save at this path.

gbtModel.write().overwrite().save('Model test/gbtModel')

你可能感兴趣的:(pyspark,sparksession)