机器学习第一个练手程序 基于决策树的iris数据预测

决策树分类器易于可视化并且易于理解。

iris数据集http://archive.ics.uci.edu/ml/datasets/Iris

这个数据集是非常典型的分类人工数据集,有3类花,每个数据有4个特征(sepal lenght,sepal width,petal length,petal width),每一类花有50个,所以这个数据集有150个数据。

我们的目标是:

(1)导入数据集

(2)训练一个分类器

(3)预测新数据的标签

(4)可视化决策树

python 的sklearn模块里可以直接导入iris数据集:


from sklearn.datasets import  load_iris

iris=load_iris()

print iris.feature_names       #打印特征

print iris.target_names             

#打印类别信息

print iris.data[0]     #打印第一行数据

完整的代码如下:

import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree
iris=load_iris()


test_idx=[0,50,100]

#training data
train_target=np.delete(iris.target,test_idx)
train_data=np.delete(iris.data,test_idx,axis=0)

#testing data

test_target=iris.target[test_idx]
test_data=iris.data[test_idx]

clf=tree.DecisionTreeClassifier()
clf.fit(train_data,train_target)


print test_target    #ground truth label of test data
print clf.predict(test_data)  # the prediction of decision tree
下面是可视化决策树的代码  我的python版本是2.7.9   

#viz code
from sklearn.externals.six import StringIO
import pydot
dot_data=StringIO()
tree.export_graphviz(clf, out_file=dot_data, 
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         impurity=False)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph[0].write_pdf("iris.pdf")
需要安装pydot 和grahphviz模块, 其中安装graphviz有些麻烦,请看http://m.blog.csdn.net/article/details?id=49472949 里给出的详细安装过程~



你可能感兴趣的:(python,机器学习)