A decision tree is a supervised-learning algorithm: a tree model is built from the attribute values of labeled sample data, and test data is then classified by walking that tree. The implementation below starts with a Node class that represents one sample, i.e. its feature values together with its class label:
package Tree;
import java.util.ArrayList;
import java.util.List;
/*
* Data node: one sample's feature values and its class label
*/
public class Node implements Cloneable{
private List<String> data; // feature values
private String type;       // class label
public Node() {
super();
}
public Node(List<String> data, String type) {
super();
this.data = data;
this.type = type;
}
public List<String> getData() {
return data;
}
public void setData(List<String> data) {
this.data = data;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public Node clone(){
Node node=null;
try {
node=(Node)super.clone();
node.type=this.type;
node.data=new ArrayList<String>();
for(int i=0;i<this.data.size();i++)
node.data.add(this.data.get(i)); // copy the feature list so the clone can be modified independently
} catch (CloneNotSupportedException e) {
e.printStackTrace();
}
return node;
}
}
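Node implements Cloneable because splitDataSet (defined later in the Tree class) removes the chosen feature column from every matching node, and it must do that on a copy so the original samples stay intact. A minimal illustration of that behaviour; the CloneDemo class name and the sample values here are made up for demonstration only:
package Tree;
import java.util.ArrayList;
import java.util.Arrays;
// Hypothetical demo class (not part of the original article): shows that clone() copies the feature list.
public class CloneDemo {
    public static void main(String[] args) {
        Node a = new Node(new ArrayList<String>(Arrays.asList("young", "myope", "no", "reduced")), "no lenses");
        Node b = a.clone();
        b.getData().remove(0);            // drop the first feature from the copy only
        System.out.println(a.getData());  // [young, myope, no, reduced] -- the original is untouched
        System.out.println(b.getData());  // [myope, no, reduced]
    }
}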
Next, define the TreeNode class that represents a node of the decision tree (at an internal node attribute holds the feature to test; at a leaf it holds the class label):
package Tree;
import java.util.List;
/*
* Tree node: the feature tested here, or the class label at a leaf
*/
public class TreeNode {
private String attribute; // feature to test, or the class label at a leaf
private List<BranchNode> branches; // outgoing branches, one per feature value (null at a leaf)
public TreeNode() {
super();
}
public TreeNode(String attribute, List<BranchNode> branches) {
super();
this.attribute = attribute;
this.branches = branches;
}
public String getAttribute() {
return attribute;
}
public void setAttribute(String attribute) {
this.attribute = attribute;
}
public List<BranchNode> getBranches() {
return branches;
}
public void setBranches(List<BranchNode> branches) {
this.branches = branches;
}
public String toString(){
String data="";
data+=attribute+":{";
if(branches!=null){
for(BranchNode node:branches){
if(null!=node.getSubTree())
data+=node.getSubTree().toString();
}
}
data+="}";
return data;
}
}
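Both TreeNode and the Tree class below rely on a BranchNode class that pairs a feature value with the subtree reached through it. Its definition does not appear in the listing above; a minimal version reconstructed to match how it is used (getValue() and getSubTree()) could look like this:
package Tree;
/*
 * Branch of the tree: the feature value on the edge plus the subtree it leads to.
 * (A minimal reconstruction consistent with how TreeNode and Tree use it; not taken from the original listing.)
 */
public class BranchNode {
    private String value;     // feature value labelling this branch
    private TreeNode subTree; // subtree reached when the sample has this value
    public BranchNode() {
        super();
    }
    public BranchNode(String value, TreeNode subTree) {
        this.value = value;
        this.subTree = subTree;
    }
    public String getValue() {
        return value;
    }
    public void setValue(String value) {
        this.value = value;
    }
    public TreeNode getSubTree() {
        return subTree;
    }
    public void setSubTree(TreeNode subTree) {
        this.subTree = subTree;
    }
}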
package Tree;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
public class Tree {
/*
* Load sample data from the given file: each tab-separated line holds the feature values followed by the class label.
*/
public List<Node> loadData(String fileName){
List<Node> list=new ArrayList<Node>();
BufferedReader br;
try {
br = new BufferedReader(new InputStreamReader(new FileInputStream(fileName)));
String line=null;
Node node=null;
List<String> data=null;
while((line=br.readLine())!=null){
node=new Node();
data=new ArrayList<String>();
String[] s=line.split("\t");
for(int i=0;i<s.length-1;i++)
data.add(s[i]); // every column except the last is a feature value
node.setData(data);
node.setType(s[s.length-1]); // the last column is the class label
list.add(node);
}
br.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
return list;
}
/*
* Compute the Shannon entropy of the class labels in a data set.
*/
public double calShannonEnt(List<Node> dataSet){
int num=dataSet.size();
HashMap<String,Integer> map=new HashMap<String,Integer>(); // class label -> count
for(Node node : dataSet){
if(map.containsKey(node.getType()))
map.put(node.getType(),map.get(node.getType())+1);
else
map.put(node.getType(), 1);
}
double shannoEnt=0;
double prob;
Set<String> set=map.keySet();
for(String s : set){
prob=(float)map.get(s)/num;
shannoEnt-=prob*(Math.log(prob)/Math.log(2));
}
return shannoEnt;
}
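// Sanity check of the formula above (illustrative numbers, not program output): a data set with
// two samples of one class and two of another has probability 0.5 for each label, so its entropy is
// -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0 bit, while a pure data set (a single class) has entropy 0.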
/*
* Return the subset of dataSet whose value for feature x equals value,
* with that feature column removed from each (cloned) node.
*/
public List<Node> splitDataSet(List<Node> dataSet,int x,String value){
List<Node> newSet=new ArrayList<Node>();
Node newNode=null;
for(Node node :dataSet){
if(node.getData().get(x).equals(value)){
newNode=node.clone();
newNode.getData().remove(x);
newSet.add(newNode);
}
}
return newSet;
}
/*
* Find the feature whose split yields the largest information gain.
*/
public int chooseBestFeatureToSplit(List<Node> dataSet){
int featNum=dataSet.get(0).getData().size(); // number of features in the data set
double oldShan=calShannonEnt(dataSet); // entropy of the current data set
int bestFeat=-1; // index of the best feature
double bestInfoGain=0.0; // largest information gain seen so far
for(int i=0;i<featNum;i++){ // evaluate each feature in turn
Set<String> set=new HashSet<String>(); // distinct values taken by feature i
for(Node node :dataSet){
set.add(node.getData().get(i));
}
double prob;
double newEntropy = 0;
for(String value: set){
List<Node> newSet=splitDataSet(dataSet, i, value);
prob=(float)newSet.size()/dataSet.size();
newEntropy+=prob*calShannonEnt(newSet);
}
if((oldShan-newEntropy)>=bestInfoGain){
bestInfoGain=oldShan-newEntropy;
bestFeat=i;
}
}
return bestFeat;
}
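// A concrete check of the arithmetic above (illustrative numbers, not taken from the lenses data):
// a parent set with 2 samples of one class and 2 of another has entropy 1.0 bit; if some feature
// splits it into two pure subsets of 2 samples each, the weighted child entropy is
// 0.5*0 + 0.5*0 = 0, so the information gain for that feature is 1.0 - 0 = 1.0 bit.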
/*
* Return the class label that occurs most often in classList.
*/
public String majorityCnt(List<String> classList){
HashMap<String,Integer> map=new HashMap<String,Integer>(); // class label -> count
ValueComparetor vc=new ValueComparetor(map);
TreeMap<String,Integer> tm=new TreeMap<String,Integer>(vc); // orders keys by descending count
for(String s:classList){
if(map.containsKey(s))
map.put(s, map.get(s)+1);
else
map.put(s, 1);
}
tm.putAll(map);
return tm.firstKey();
}
class ValueComparetor implements Comparator<String>{
Map<String,Integer> map;
public ValueComparetor(Map<String,Integer> map ){
this.map=map;
}
public int compare(String arg0, String arg1) {
if(map.get(arg0)>=map.get(arg1))
return -1;
else
return 1;
}
}
/*
* Recursively build the decision tree for the given data set
* (at each step, split on the feature with the largest information gain).
*/
public TreeNode createTree(List<Node> dataSet,List<String> labels){
TreeNode tree=null;
List<BranchNode> branches=new ArrayList<BranchNode>();
List<String> classList=getClassList(dataSet); // class labels of the current data set
if(dataSet.get(0).getData().size()==0){ // all features used up: return a leaf with the majority class
tree=new TreeNode(majorityCnt(classList),null);
return tree;
}
Set<String> set=new HashSet<String>(classList);
if(set.size()==1){ // all samples belong to the same class: return a leaf
tree=new TreeNode(classList.get(0),null );
return tree;
}
int bestFeat=chooseBestFeatureToSplit(dataSet); // index of the best split feature
String bestFeatLabel=labels.get(bestFeat); // name of the best split feature
labels.remove(bestFeat); // remove the used feature name from the label list
Set<String> labelDataSet=new HashSet<String>(); // distinct values of the chosen feature
for(Node node:dataSet){
labelDataSet.add(node.getData().get(bestFeat));
}
for(String value : labelDataSet){ // build one branch per value of the chosen feature
List<String> subLabels=new ArrayList<String>(labels);
branches.add(new BranchNode(value, createTree(splitDataSet(dataSet, bestFeat, value), subLabels)));
}
tree=new TreeNode(bestFeatLabel, branches);
return tree;
}
/*
* Collect the class labels of all samples in the data set.
*/
public List<String> getClassList(List<Node> dataSet){
List<String> classList=new ArrayList<String>();
for(Node node :dataSet){
classList.add(node.getType());
}
return classList;
}
/*
* Classify a test sample by walking the decision tree; returns null if the tree
* has no branch for one of the sample's feature values.
*/
public String clarrify(TreeNode tree,List<String> labelsList,Node test){
if(tree.getBranches()==null){
return tree.getAttribute();
}
List<BranchNode> branches=tree.getBranches();
for(BranchNode branch : branches){
int index=labelsList.indexOf(tree.getAttribute());
if(branch.getValue().equals(test.getData().get(index)))
return clarrify( branch.getSubTree(),labelsList,test);
}
return null;
}
public static void main(String[] args) {
String sampFile="E:\\Java_Project\\DeepLearning\\src\\Tree\\lenses.txt";
Tree t=new Tree();
List<Node> data=t.loadData(sampFile);
String []labels={"age","prescript1", "astigmatic1", "tearRate1"}; // feature names, one per column
List<String> labelsList=new ArrayList<String>();
Collections.addAll(labelsList, labels);
TreeNode tree=t.createTree(data, labelsList);
System.out.println("Tree: "+tree); // print the decision tree
String []testData={"presbyopic","myope","yes","normal"};
List<String> test=new ArrayList<String>();
Collections.addAll(test,testData);
Node node=new Node(test,null);
List<String> labelsList2=new ArrayList<String>(); // createTree consumed labelsList, so rebuild the label list
Collections.addAll(labelsList2, labels);
System.out.println(t.clarrify(tree, labelsList2, node));
}
}
The sample file lenses.txt contains the 24 records below; in the actual file the columns are separated by tabs (loadData splits on "\t"), the first four columns are the features named in the labels array, and the last column is the class label (no lenses, soft, or hard):
young myope no reduced no lenses
young myope no normal soft
young myope yes reduced no lenses
young myope yes normal hard
young hyper no reduced no lenses
young hyper no normal soft
young hyper yes reduced no lenses
young hyper yes normal hard
pre myope no reduced no lenses
pre myope no normal soft
pre myope yes reduced no lenses
pre myope yes normal hard
pre hyper no reduced no lenses
pre hyper no normal soft
pre hyper yes reduced no lenses
pre hyper yes normal no lenses
presbyopic myope no reduced no lenses
presbyopic myope no normal no lenses
presbyopic myope yes reduced no lenses
presbyopic myope yes normal hard
presbyopic hyper no reduced no lenses
presbyopic hyper no normal soft
presbyopic hyper yes reduced no lenses
presbyopic hyper yes normal no lenses
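Two quick plausibility checks, worked out by hand from the listing above rather than from a program run: the 24 samples contain 15 "no lenses", 5 "soft" and 4 "hard" labels, so calShannonEnt on the full data set should come out to about -(15/24)*log2(15/24) - (5/24)*log2(5/24) - (4/24)*log2(4/24) ≈ 1.326 bits. And the hard-coded test sample (presbyopic, myope, yes, normal) appears verbatim in the training file with the label hard, so a tree that fits the training data should print hard for it.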