pip3 install xgboost
A description of the mushroom dataset is available at:
https://archive.ics.uci.edu/ml/datasets/Mushroom
The mushroom attributes are described as follows:
Attribute Information:
1. cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s
2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s
3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r, pink=p,purple=u,red=e,white=w,yellow=y
4. bruises?: bruises=t,no=f
5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f, musty=m,none=n,pungent=p,spicy=s
6. gill-attachment: attached=a,descending=d,free=f,notched=n
7. gill-spacing: close=c,crowded=w,distant=d
8. gill-size: broad=b,narrow=n
9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e, white=w,yellow=y
10. stalk-shape: enlarging=e,tapering=t
11. stalk-root: bulbous=b,club=c,cup=u,equal=e, rhizomorphs=z,rooted=r,missing=?
12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s
13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s
14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y
15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y
16. veil-type: partial=p,universal=u
17. veil-color: brown=n,orange=o,white=w,yellow=y
18. ring-number: none=n,one=o,two=t
19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l, none=n,pendant=p,sheathing=s,zone=z
20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r, orange=o,purple=u,white=w,yellow=y
21. population: abundant=a,clustered=c,numerous=n, scattered=s,several=v,solitary=y
22. habitat: grasses=g,leaves=l,meadows=m,paths=p, urban=u,waste=w,woods=d
The original format of the mushroom dataset looks like this:
p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g
The first column is the label: edible=e (non-poisonous), poisonous=p; the remaining columns are the features.
Clearly, all of these feature values are categorical, corresponding one-to-one with the 22 attributes described above. Categorical features like these generally need to be one-hot encoded, either with sklearn's DictVectorizer or with a hand-written encoder, and then written out in xgboost's LibSVM-style input format (a conversion sketch is given after the example lines below):
label feature_id:feature_value ...
For example, the mushroom dataset bundled with the xgboost tutorial looks like this:
1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
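For reference, a minimal sketch of that conversion, using sklearn's DictVectorizer together with dump_svmlight_file, might look like the following; the local file name agaricus-lepiota.data and the attribute-name list are assumptions taken from the UCI page above, not part of the xgboost repo:

import csv
from sklearn.feature_extraction import DictVectorizer
from sklearn.datasets import dump_svmlight_file

# the 22 attribute names from the UCI description above
ATTRS = ['cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
         'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
         'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
         'stalk-surface-below-ring', 'stalk-color-above-ring',
         'stalk-color-below-ring', 'veil-type', 'veil-color',
         'ring-number', 'ring-type', 'spore-print-color',
         'population', 'habitat']

labels, rows = [], []
with open('agaricus-lepiota.data') as f:              # raw UCI file, assumed downloaded locally
    for record in csv.reader(f):
        labels.append(1 if record[0] == 'p' else 0)   # p = poisonous -> 1, e = edible -> 0
        rows.append(dict(zip(ATTRS, record[1:])))     # attribute name -> letter code

vec = DictVectorizer()                                # one-hot encodes each (attribute, value) pair
X = vec.fit_transform(rows)                           # sparse 0/1 matrix, one column per pair
dump_svmlight_file(X, labels, 'mushroom.libsvm')      # writes "label feature_id:1 ..." lines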
Since the xgboost repo already ships the mushroom dataset in this format, we will just use it directly (being lazy here~).
First clone the https://github.com/dmlc/xgboost repo; the preprocessed mushroom dataset is under demo/data.
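Something like this is enough to grab it (the demo/data path comes from the repo layout mentioned above):

git clone https://github.com/dmlc/xgboost.git
# the preprocessed files sit under xgboost/demo/data/, e.g. agaricus.txt.train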
Now run the cross-validation code:
#!/usr/bin/python
import numpy as np
import xgboost as xgb
import json

### load data and do training
dtrain = xgb.DMatrix('data/agaricus.txt.train')
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
num_round = 2

print('running cross validation')
# do cross validation; this will print results as
# [iteration] metric_name:mean_value+std_value
# where std_value is the standard deviation of the metric across folds
xgb.cv(param, dtrain, num_round, nfold=5,
       metrics={'error'}, seed=0,
       callbacks=[xgb.callback.print_evaluation(show_stdv=True)])

print('running cross validation, disable standard deviation display')
# do cross validation; this will print results as
# [iteration] metric_name:mean_value
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
             metrics={'error', 'auc', 'rmse'}, seed=0,
             callbacks=[xgb.callback.print_evaluation(show_stdv=False),
                        xgb.callback.early_stop(3)])
print(json.dumps(res, indent=4))
The evaluation metrics are error, auc, and rmse; the output looks like this:
[22:46:13] 6513x127 matrix with 143286 entries loaded from ../data/agaricus.txt.train
running cross validation
[0] train-error:0.0506682+0.009201 test-error:0.0557316+0.0158887
[1] train-error:0.0213034+0.00205561 test-error:0.0211884+0.00365323
running cross validation, disable standard deviation display
[0] train-auc:0.962827 train-error:0.0506682 train-rmse:0.230558 test-auc:0.960993 test-error:0.0557316 test-rmse:0.234907
Multiple eval metrics have been passed: 'test-rmse' will be used for early stopping.
Will train until test-rmse hasn't improved in 3 rounds.
[1] train-auc:0.984829 train-error:0.0213034 train-rmse:0.159816 test-auc:0.984753 test-error:0.0211884 test-rmse:0.159676
[2] train-auc:0.99763 train-error:0.0099418 train-rmse:0.111782 test-auc:0.997216 test-error:0.0099786 test-rmse:0.113232
[3] train-auc:0.998845 train-error:0.0141256 train-rmse:0.104002 test-auc:0.998575 test-error:0.0144336 test-rmse:0.105863
[4] train-auc:0.999404 train-error:0.0059878 train-rmse:0.078452 test-auc:0.999001 test-error:0.0062948 test-rmse:0.0814638
[5] train-auc:0.999571 train-error:0.0020344 train-rmse:0.0554116 test-auc:0.999236 test-error:0.0016886 test-rmse:0.0549618
[6] train-auc:0.999643 train-error:0.0012284 train-rmse:0.0442974 test-auc:0.999389 test-error:0.001228 test-rmse:0.0447266
[7] train-auc:0.999736 train-error:0.0012284 train-rmse:0.0409082 test-auc:0.999535 test-error:0.001228 test-rmse:0.0408704
[8] train-auc:0.999967 train-error:0.0009212 train-rmse:0.0325856 test-auc:0.99992 test-error:0.001228 test-rmse:0.0378632
[9] train-auc:0.999982 train-error:0.0006142 train-rmse:0.0305786 test-auc:0.999959 test-error:0.001228 test-rmse:0.0355032
{
"train-auc-mean": [
0.9628270000000001,
0.9848285999999999,
0.9976303999999999,
0.9988453999999999,
0.9994036000000002,
0.9995708000000001,
0.9996430000000001,
0.9997361999999999,
0.9999673999999998,
0.9999824
],
"train-auc-std": [
0.008448006628785306,
0.006803294807664875,
0.0009387171245907945,
0.0003538839357755705,
0.0002618454506001579,
0.00025375531521528084,
0.00020217319307960657,
0.00020191027710344742,
3.975726348731663e-05,
2.0401960690100178e-05
],
"train-error-mean": [
0.0506682,
0.0213034,
0.009941799999999999,
0.014125599999999999,
0.0059878,
0.0020344,
0.0012284,
0.0012284,
0.0009212,
0.0006142
],
"train-error-std": [
0.009200997193782855,
0.0020556122786167634,
0.006076479256938181,
0.0017057689878761427,
0.0018779069625516596,
0.001469605198684327,
0.00026026494193417596,
0.00026026494193417596,
0.0005061973528180487,
0.000506318634853587
],
"train-rmse-mean": [
0.2305582,
0.1598156,
0.11178239999999999,
0.104002,
0.07845200000000001,
0.05541159999999999,
0.0442974,
0.04090819999999999,
0.0325856,
0.0305786
],
"train-rmse-std": [
0.0037922719522734682,
0.01125541159798254,
0.002488777016930202,
0.004049203625405865,
0.0013676392799272755,
0.0037808676041353262,
0.0029096398127603355,
0.002357843285716845,
0.006475110519520113,
0.005710078619423729
],
"test-auc-mean": [
0.9609932000000001,
0.9847534,
0.9972156,
0.9985745999999999,
0.9990005999999999,
0.9992364,
0.9993892000000001,
0.9995353999999999,
0.9999202,
0.9999590000000002
],
"test-auc-std": [
0.01004388481415432,
0.0073321434137637925,
0.001719973674217141,
0.0008885571675474929,
0.0007358270448957148,
0.0007796844489920071,
0.0006559303011753682,
0.00040290524940736174,
9.667760857612425e-05,
4.69212105555625e-05
],
"test-error-mean": [
0.055731600000000006,
0.021188400000000003,
0.009978599999999999,
0.0144336,
0.006294800000000001,
0.0016885999999999997,
0.001228,
0.001228,
0.001228,
0.001228
],
"test-error-std": [
0.015888666194492227,
0.0036532266614597024,
0.004827953421482027,
0.003517125508138713,
0.0031231752688570006,
0.0005741844999649501,
0.0010409403441119958,
0.0010409403441119958,
0.0010409403441119958,
0.0010409403441119958
],
"test-rmse-mean": [
0.2349068,
0.1596758,
0.113232,
0.10586340000000001,
0.08146379999999999,
0.054961800000000005,
0.04472659999999999,
0.040870399999999994,
0.0378632,
0.0355032
],
"test-rmse-std": [
0.009347684363520205,
0.013838324138420807,
0.006189637210693368,
0.008622184470306812,
0.007872053592297248,
0.009846015506792583,
0.010218353460318349,
0.0130493180756697,
0.013843686595701305,
0.013998153155327313
]
}
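To pick the best round out of res programmatically, something like the following works, assuming res is the dict of per-round lists shown above (with pandas installed, xgb.cv returns a DataFrame instead, and res['test-rmse-mean'].idxmin() gives the same answer):

# index of the round with the lowest mean test rmse, and its value
best_round = min(range(len(res['test-rmse-mean'])), key=lambda i: res['test-rmse-mean'][i])
print(best_round, res['test-rmse-mean'][best_round])  # 9 0.0355032 for the run above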
Judging by the rmse, xgboost's classification performance on this dataset is quite good.