Marking categorical variables with categoricalFeaturesInfo in PySpark

We use PySpark's random forest as an example:

#! /usr/bin/python3
# -*- coding: utf-8 -*-

from pyspark import SparkContext, SparkConf
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import RandomForest
from pyspark.sql import SQLContext

# Spark configuration (local mode, 10 threads); typically supplied on the command line when using spark-submit
conf = SparkConf().setAppName("Test Application")
conf = conf.setMaster("local[10]")
sc = SparkContext(conf=conf)
sqlCtx = SQLContext(sc)

def create_label_point(line):
    # Last column is the label; the remaining columns are the (pre-encoded) features
    line = line.strip().split(',')
    return LabeledPoint(int(line[-1]), [float(x) for x in line[:-1]])


train = sc.textFile("file:///home/hujianqiu/20eg/BLOGGER/kohkiloyeh_train").map(create_label_point)
test = sc.textFile("file:///home/hujianqiu/20eg/BLOGGER/kohkiloyeh_test").map(create_label_point)



#print("rf start")
model = RandomForest.trainClassifier(train, numClasses=2,
                                     categoricalFeaturesInfo={0: 3, 1: 3, 2: 5, 3: 2, 4: 2},
                                     numTrees=50,
                                     featureSubsetStrategy="auto",
                                     impurity="gini",
                                     maxDepth=5,
                                     maxBins=100,
                                     seed=12345)

predictions = model.predict(test.map(lambda x: x.features))
labels_and_preds = test.map(lambda p: p.label).zip(predictions)

# Confusion matrix counts, keyed by (label, prediction)
testErr_11 = labels_and_preds.filter(lambda vp: vp == (1, 1)).count()
testErr_10 = labels_and_preds.filter(lambda vp: vp == (1, 0)).count()
testErr_01 = labels_and_preds.filter(lambda vp: vp == (0, 1)).count()
testErr_00 = labels_and_preds.filter(lambda vp: vp == (0, 0)).count()


accuracy = float(testErr_11 + testErr_00) / float(testErr_11 + testErr_10 + testErr_01 + testErr_00)
recall = float(testErr_11) / float(testErr_11 + testErr_10)
precision = float(testErr_11) / float(testErr_11 + testErr_01)
F1_measure = 2 * precision * recall / (precision + recall)

with open('/home/hujianqiu/20eg/BLOGGER/result.txt','w') as f:
    f.write('testErr_11:\t%d\n'%testErr_11)
    f.write('testErr_10:\t%d\n'%testErr_10)
    f.write('testErr_01:\t%d\n'%testErr_01)
    f.write('testErr_00:\t%d\n'%testErr_00)
    f.write('accuracy:\t%f\n'%accuracy)
    f.write('recall:\t\t%f\n'%recall)
    f.write('precision:\t%f\n'%precision)
    f.write('F1_measure:\t%f\n'%F1_measure)
    #f.write(model.toDebugString())
# Save and load model (loading also needs: from pyspark.mllib.tree import RandomForestModel)
# model.save(sc, "/home/air/hjq/proofread-randomforesst/myRandomForestClassificationModel")
# sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel")
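
Not part of the original script, but the hand-rolled confusion matrix above can be cross-checked with MLlib's built-in MulticlassMetrics. The lines below are a minimal sketch that assumes Spark 2.x (the accuracy property does not exist before 2.0) and would slot in right after predictions is computed:

from pyspark.mllib.evaluation import MulticlassMetrics

# MulticlassMetrics expects (prediction, label) pairs
preds_and_labels = predictions.zip(test.map(lambda p: p.label))
metrics = MulticlassMetrics(preds_and_labels)
print(metrics.confusionMatrix().toArray())
print(metrics.accuracy)          # overall accuracy, Spark >= 2.0
print(metrics.precision(1.0))    # precision for the positive class
print(metrics.recall(1.0))       # recall for the positive class
print(metrics.fMeasure(1.0))     # F1 for the positive class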

In the call to RandomForest.trainClassifier above you can see a parameter named categoricalFeaturesInfo, which is simply a dict.
If there are no categorical variables, pass categoricalFeaturesInfo={}.
If there are categorical variables, the dict must be populated:
- each key is the position (index) of a feature, and its value is the number of categories that feature takes;
- the first feature has index 0, the second index 1, and so on;
- the data must be encoded in advance as 0, 1, ..., N-1, where N is that feature's number of categories, and the encoding must agree with what categoricalFeaturesInfo declares (a minimal encoding sketch follows this list).
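
To make the last point concrete, here is a minimal, hypothetical sketch of such a pre-encoding step. The column values and the codebook below are invented for illustration only; what matters is the 0, 1, ..., N-1 convention and that it agrees with categoricalFeaturesInfo:

from pyspark.mllib.regression import LabeledPoint

# Hypothetical codebook for a string-valued first column with 3 categories
# ("low"/"medium"/"high" are made-up values, not taken from the kohkiloyeh data)
FEATURE0_CODES = {"low": 0, "medium": 1, "high": 2}

def encode_line(line):
    raw = line.strip().split(',')
    # Encode the categorical first column as 0..2; the remaining feature
    # columns are assumed to be numeric already; the last column is the label
    features = [float(FEATURE0_CODES[raw[0]])] + [float(x) for x in raw[1:-1]]
    return LabeledPoint(int(raw[-1]), features)

# A 3-category feature at position 0 is then declared as:
#   categoricalFeaturesInfo={0: 3, ...}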

kohkiloyeh_train
0,2,1,0,0,0
0,2,2,0,0,0
1,1,3,0,0,0
1,0,0,0,1,0
0,0,0,0,0,0
0,2,2,0,0,1
0,2,2,0,0,0
0,0,2,0,0,0
1,2,1,0,0,0
0,2,2,0,0,0
1,2,3,0,0,0
2,2,0,0,1,1
0,2,3,0,0,0
0,2,2,0,0,0
2,0,1,1,1,0
1,2,4,0,0,1
2,2,4,0,0,1
1,0,0,0,0,1
0,2,1,0,1,0
1,2,0,0,1,0
1,1,4,0,1,0
1,1,1,1,0,1
1,0,4,0,0,1
0,2,2,0,0,1
1,0,3,0,0,1
1,1,3,0,0,0
2,0,1,0,1,1
1,0,3,0,0,1
1,1,3,0,0,0
1,1,3,0,0,0
0,2,2,0,0,0
1,0,0,0,1,0
1,0,0,0,0,0
2,0,3,1,0,1
0,2,2,0,0,0
1,2,2,0,0,0
0,0,1,0,0,1
1,2,2,0,0,0
1,2,3,0,0,0
1,0,2,1,0,1
2,0,1,1,1,0
0,0,2,0,0,0
0,2,1,1,1,0
1,2,4,0,0,1
0,0,2,0,0,0
2,0,2,0,0,0
1,1,1,1,0,1
1,2,1,1,1,0
1,2,3,1,0,0
1,2,0,0,1,0
2,1,1,0,1,1
1,2,1,0,0,0

kohkiloyeh_test
1,1,0,0,0,0
0,2,2,0,0,0
1,1,3,0,0,0
0,2,2,0,0,0
0,0,2,0,1,0
0,0,2,0,1,1
1,2,3,0,1,0
2,0,3,1,0,1
1,2,1,0,0,0
1,2,2,0,0,0
0,0,0,0,0,1
0,0,2,1,0,1
0,0,2,0,0,0
0,2,1,1,1,0
0,0,2,0,0,0
2,0,2,0,0,0
1,2,1,1,1,0
1,2,3,1,0,0
0,2,2,0,0,0
1,2,0,0,1,0
2,1,1,0,1,1
1,2,1,0,0,0
0,2,2,0,0,0
0,2,2,0,0,0
1,1,0,0,0,0
0,2,2,0,0,0
0,0,2,0,1,0
0,0,2,0,1,1
1,2,3,0,1,0
0,2,1,0,0,1
1,2,1,0,0,0
0,0,2,0,0,0
1,2,2,0,0,0
2,2,0,0,1,1
0,2,3,0,0,0
0,2,2,0,0,0
2,2,4,0,0,1
1,0,0,0,0,1
0,2,1,0,1,0
1,2,0,0,1,0
1,1,1,0,1,0
1,0,4,0,0,1
0,2,2,0,0,1
0,2,2,0,0,0
1,0,3,0,0,1
1,1,1,0,0,0
2,0,1,0,1,1
1,0,3,0,0,1
