from collections import Counter

import pydotplus
from sklearn import tree
def cart_skl_test():
    """Train scikit-learn's decision tree on the liquefaction data and render it.

    Reads ``../dataSet/liquefaction_data_MLE.csv``, fits a
    ``tree.DecisionTreeClassifier`` on the ``CSR`` and ``Vs`` columns with
    ``target`` as the label, and writes the Graphviz rendering of the fitted
    tree to ``cartTree.png``.
    """
    frame = pd.read_csv("../dataSet/liquefaction_data_MLE.csv")
    features, labels = frame[['CSR', 'Vs']], frame['target']
    # fit() returns the estimator itself, so construction and training chain.
    classifier = tree.DecisionTreeClassifier().fit(features, labels)
    graph = pydotplus.graph_from_dot_data(
        tree.export_graphviz(classifier, out_file=None)
    )
    graph.write_png("cartTree.png")
import pandas as pd
import math
def get_gini(dataSet):
    """Return the Gini impurity of the class labels in ``dataSet``.

    The class label of each row is its last element. The impurity is
    ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of rows carrying
    label ``k``.

    Args:
        dataSet: sequence of rows (lists or arrays); ``row[-1]`` is the label.

    Returns:
        A float between 0 and 1. An empty ``dataSet`` contains no mixed
        labels and therefore yields ``0.0`` rather than a maximal impurity.
    """
    num_instances = len(dataSet)
    if num_instances == 0:
        return 0.0
    # Count how many rows carry each label (insertion order = first seen).
    label_counts = Counter(featVec[-1] for featVec in dataSet)
    sum_prob = sum((count / num_instances) ** 2 for count in label_counts.values())
    return 1 - sum_prob
`dataSet` 是数据集，`axis` 是第几个特征（列索引），`value` 是该特征的划分阈值。
该函数根据数据集中第 `axis` 个特征的值与 `value` 的比较结果对数据进行划分：该特征值小于等于 `value` 的样本划入左子集，其余样本划入右子集。
def splitDataSet(dataSet, axis, value):
    """Partition ``dataSet`` on a threshold over column ``axis``.

    Rows whose ``axis``-th entry is ``<= value`` go to the left partition;
    every other row goes to the right. Rows are not copied and keep their
    original relative order.

    Returns:
        A ``(leftDataSet, rightDataSet)`` pair of lists.
    """
    left, right = [], []
    for row in dataSet:
        # Pick the destination list, then append the row to it.
        (left if row[axis] <= value else right).append(row)
    return left, right
def _best_threshold_for_feature(dataSet, i):
    """Find the best ``<=`` threshold on feature column ``i``.

    Candidate thresholds are the midpoints between consecutive distinct
    (sorted) values of the column. Each candidate is scored by the
    size-weighted Gini impurity of the two partitions it produces.

    Returns:
        ``(featureSplit, featureGini)`` for the lowest-scoring candidate. If
        the column has fewer than two distinct values there is no candidate
        and ``(-1, 1.0)`` is returned.
    """
    uniqueVals = sorted(set(example[i] for example in dataSet))
    total = float(len(dataSet))
    featureSplit = -1
    featureGini = 1.0
    for lower, upper in zip(uniqueVals, uniqueVals[1:]):
        value = (lower + upper) / 2
        left_dataSet, right_dataSet = splitDataSet(dataSet, i, value)
        prob = len(left_dataSet) / total
        currentGini = prob * get_gini(left_dataSet) + (1 - prob) * get_gini(right_dataSet)
        # Strict < keeps the smallest threshold when scores tie.
        if currentGini < featureGini:
            featureGini = currentGini
            featureSplit = value
    return featureSplit, featureGini


def chooseBestFeatureToSplit(dataSet, verbose=True):
    """Choose the feature and threshold with the lowest weighted Gini impurity.

    Every column except the last (the class label) is tried as a candidate
    feature; thresholds are scored by ``_best_threshold_for_feature``.

    Args:
        dataSet: non-empty list of rows; ``row[-1]`` is the class label.
        verbose: when true (the default), print the chosen feature index,
            threshold, and the Gini impurity of this node before splitting.
            Pass ``False`` to silence the per-node output.

    Returns:
        ``(bestFeature, bestSplitValue, bestInfoGini)``. When no feature has
        at least two distinct values no split exists and ``(-1, -1, 1.0)``
        is returned.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the label, not a feature
    bestFeature = -1
    bestSplitValue = -1
    bestInfoGini = 1.0
    for i in range(numFeatures):
        featureSplit, featureGini = _best_threshold_for_feature(dataSet, i)
        # Strict < keeps the lowest-indexed feature when scores tie.
        if featureGini < bestInfoGini:
            bestInfoGini = featureGini
            bestFeature = i
            bestSplitValue = featureSplit
    if verbose:
        # "Gini" here is the impurity of the node itself, before splitting.
        print("bestFeature: {}, bestSplitValue: {}, Gini: {}".format(
            bestFeature, bestSplitValue, get_gini(dataSet)))
    return bestFeature, bestSplitValue, bestInfoGini
def createTree(dataSet, paraFeatureName):
    """Recursively build a binary CART classification tree.

    Args:
        dataSet: non-empty list of rows; ``row[-1]`` is the class label and
            the remaining entries are numeric feature values.
        paraFeatureName: names indexed like the row columns (e.g.
            ``df.columns.values``); used as keys of the returned dict.

    Returns:
        A class label for a leaf, otherwise a nested dict of the form
        ``{feature_name: {"<=t": left_subtree, ">t": right_subtree}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Every row carries the same label -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    bestFeat, bestSplit, _ = chooseBestFeatureToSplit(dataSet)
    # chooseBestFeatureToSplit reports -1 when no feature has two distinct
    # values (e.g. duplicate feature vectors with conflicting labels).
    # Proceeding would use paraFeatureName[-1] and split on the label column
    # itself, which can yield an empty subset and an IndexError on
    # classList[0] in the recursive call; fall back to a majority-vote leaf.
    if bestFeat == -1:
        return Counter(classList).most_common(1)[0][0]
    bestFeatureName = paraFeatureName[bestFeat]
    leftTree, rightTree = splitDataSet(dataSet, bestFeat, bestSplit)
    # Keys embed the threshold so the dict alone describes each split.
    # The left subtree is built before the right one, as before.
    return {bestFeatureName: {
        "<=" + str(bestSplit): createTree(leftTree, paraFeatureName),
        ">" + str(bestSplit): createTree(rightTree, paraFeatureName),
    }}
if __name__ == "__main__":
    # cart_skl_test()  # uncomment to also fit sklearn's tree and write cartTree.png
    df = pd.read_csv("../dataSet/liquefaction_data_MLE.csv")  # read the .csv data
    featureName = df.columns.values
    # One list entry per CSV row; createTree treats the last column as the label.
    dataSet = list(df.values)
    # Named cart_tree rather than `tree` so the sklearn `tree` module imported
    # at the top of the file is not shadowed.
    cart_tree = createTree(dataSet, featureName)
    print(cart_tree)
运行结果（打印出的决策树）：
{'Vs':
{'<=16.35':
{'Vs':
{'<=6.550000000000001':
0.0,
'>6.550000000000001':
{'Vs':
{'<=11.2':
1.0,
'>11.2':
{'Vs':
{'<=15.55':
{'CSR':
{'<=0.16999999999999998':
0.0,
'>0.16999999999999998':
{'CSR':
{'<=0.26':
{'Vs':
{'<=11.45':
0.0,
'>11.45':
1.0}},
'>0.26':
{'Vs':
{'<=13.25':
1.0,
'>13.25':
0.0}}}}}},
'>15.55': 1.0}}}}}},
'>16.35':
{'CSR':
{'<=0.29500000000000004':
0.0,
'>0.29500000000000004':
{'CSR':
{'<=0.32999999999999996':
1.0,
'>0.32999999999999996':
0.0}}}}}}