此前已经上传了ID3决策树的Java实现,C4.5整体架构与之相差不大。
可参考:http://blog.csdn.net/xiaohukun/article/details/78041676
此次将结点的实现由Dom4J改为自定义类实现,更加自由和轻便。
代码已打包并上传
数据仍采用ARFF格式
train.arff
@relation weather.symbolic
@attribute outlook {sunny,overcast,rainy}
@attribute temperature {hot,mild,cool}
@attribute humidity {high,normal}
@attribute windy {TRUE,FALSE}
@attribute play {yes,no}
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
C4.5类(主类)
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.io.FileOutputStream;
import java.io.BufferedOutputStream;
import java.lang.Math.*;
public class DecisionTree {
private ArrayList train_AttributeName = new ArrayList(); // 存储训练集属性的名称
private ArrayList> train_attributeValue = new ArrayList>(); // 存储训练集每个属性的取值
private ArrayList trainData = new ArrayList(); // 训练集数据 ,即arff文件中的data字符串
public static final String patternString = "@attribute(.*)[{](.*?)[}]";
//正则表达,其中*? 表示重复任意次,但尽可能少重复,防止匹配到更后面的"}"符号
private int decatt; // 决策变量在属性集中的索引(即类标所在列)
private InfoGain infoGain;
private TreeNode root;
public void train(String data_path, String targetAttr){
//模型初始化操作
read_trainARFF(new File(data_path));
//printData();
setDec(targetAttr);
infoGain=new InfoGain(trainData, decatt);
//拼装行与列
LinkedList ll=new LinkedList(); //LinkList用于增删比ArrayList有优势
for(int i = 0; i< train_AttributeName.size(); i++){
if(i!=decatt) ll.add(i); //防止类别变量不在最后一列发生错误
}
ArrayList al=new ArrayList();
for(int i=0;i//构建决策树
root = buildDT("root", "null", al, ll);
//剪枝
cutBranch(root);
}
/**
* 构建决策树
* @param fatherName 节点名称
* @param fatherValue 节点值
* @param subset 数据行子集
* @param subset 数据列子集
* @return 返回根节点
*/
public TreeNode buildDT(String fatherName, String fatherValue, ArrayList subset,LinkedList selatt){
TreeNode node=new TreeNode();
Map targetNum = infoGain.get_AttributeNum(subset,decatt);//计算类-频率
String targetValue=infoGain.get_targetValue(targetNum);//判定分类
node.setTargetNum(targetNum);
node.setAttributeName(fatherName);
node.setAttributeValue(fatherValue);
node.setTargetValue(targetValue);
//终止条件为类标单一/树深度达到特征长度(还有可能是信息增益率不存在)
if (infoGain.isPure(targetNum) | selatt.isEmpty() ) {
node.setNodeType("leafNode");
return node;
}
int maxIndex = infoGain.getGainRatioMax(subset,selatt);
selatt.remove(new Integer(maxIndex)); //这样可以remove object
String childName = train_AttributeName.get(maxIndex);
Map> childSubset = infoGain.get_AttributeSubset(subset, maxIndex);
ArrayList childNode = new ArrayList();
for (String childValue : childSubset.keySet()){
TreeNode child = buildDT(childName, childValue, childSubset.get(childValue), selatt);
child.setFatherTreeNode(node); //顺序很重要:回溯
childNode.add(child);
}
node.setChildTreeNode(childNode);
return node;
}
/**
* 剪枝函数
* @param node 判断结点
* @return 剪枝之后的叶子结点集
*/
public ArrayList<int[]> cutBranch(TreeNode node){
ArrayList<int[]> resultNum = new ArrayList<int[]>();
if (node.getNodeType() =="leafNode"){
int[] tempNum = get_leafNum(node);
resultNum.add(tempNum);
return resultNum;
}else{
int sumNum = 0;
double oldRatio = 0;
for (TreeNode child : node.getChildTreeNode()){
for(int[] leafNum : cutBranch(child)){
resultNum.add(leafNum);
oldRatio += 0.5 + leafNum[0];
sumNum += leafNum[1];
}
}
double oldNum =oldRatio;
oldRatio /= sumNum;
double sd = Math.sqrt(sumNum*oldRatio*(1-oldRatio));
int temLeaf[] = get_leafNum(node);
double newNum = temLeaf[0] + 0.5;
if(newNum < oldNum + sd){//符合剪枝条件,剪枝并返回本身
node.setChildTreeNode(null);
node.setNodeType("leafNode");
resultNum.clear();
resultNum.add(temLeaf);
}//不符合剪枝条件,返回叶子结点
return resultNum;
}
}
//获得叶子结点的数目
public int[] get_leafNum(TreeNode node){
int[] resultNum= new int[2];
Map targetNum = node.getTargetNum();
int minNum = Integer.MAX_VALUE;
int sumNum = 0;
for(int num : targetNum.values()){
minNum = Integer.min(minNum, num);
sumNum += num;
}
if (targetNum.size() == 1) minNum = 0;
resultNum[0] = minNum;
resultNum[1] = sumNum;
return resultNum;
}
/**
* 读取arff文件,给attribute、attributevalue、data赋值
* @param file 传入的文件
*/
public void read_trainARFF(File file) {
try {
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
Pattern pattern = Pattern.compile(patternString);
while ((line = br.readLine()) != null) {
Matcher matcher = pattern.matcher(line);
if (matcher.find()) {
train_AttributeName.add(matcher.group(1).trim()); //获取第一个括号里的内容
//涉及取值,尽量加.trim(),后面也可以看到,即使是换行符也可能会造成字符串不相等
String[] values = matcher.group(2).split(",");
ArrayList al = new ArrayList(values.length);
for (String value : values) {
al.add(value.trim());
}
train_attributeValue.add(al);
} else if (line.startsWith("@data")) {
while ((line = br.readLine()) != null) {
if(line=="")
continue;
String[] row = line.split(",");
trainData.add(row);
}
} else {
continue;
}
}
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 打印Data
*/
public void printData(){
System.out.println("当前的ATTR为");
for(String attr : train_AttributeName){
System.out.print(attr+" ");
}
System.out.println();
System.out.println("---------------------------------");
System.out.println("当前的DATA为");
for(String[] row: trainData){
for (String value : row){
System.out.print(value+" ");
}
System.out.println();
}
System.out.println("---------------------------------");
}
//将决策树存储到xml文件中
public void write_DecisionTree(String filename) {
try {
File file = new File(filename);
if (!file.exists())
file.createNewFile();
FileOutputStream fs = new FileOutputStream(filename);
BufferedOutputStream bos = new BufferedOutputStream(fs);
write_Node(bos, root, "");
bos.flush();
bos.close();
fs.close();
}catch (IOException e){
e.printStackTrace();
}
}
private void write_Node(BufferedOutputStream bos, TreeNode node, String block){
String outputWords1 = block + "<" + node.getAttributeName()+ " value=\"" + node.getAttributeValue() + "\"";
String outputWords2;
Map targetNum = node.getTargetNum();
for (String value : targetNum.keySet()){
outputWords1 += " " + value + ":" + targetNum.get(value);
}
outputWords1 += ">";
if(node.getNodeType()=="leafNode"){
outputWords1 += node.getTargetValue();
outputWords2 = "" + node.getAttributeName() + ">" + "\n";
}else{
outputWords1 += "\n";
outputWords2 = block + "" + node.getAttributeName() + ">" + "\n";
}
try {
bos.write(outputWords1.getBytes());
}catch (IOException e){
e.printStackTrace();
}
ArrayList childNode=node.getChildTreeNode();
if (childNode !=null){
for (TreeNode child : childNode){
write_Node(bos, child, block+" ");
}
}
try {
bos.write(outputWords2.getBytes());
}catch (IOException e){
System.out.println(e.getMessage());
}
}
//设置决策变量
public void setDec(int n) {
if (n < 0 || n >= train_AttributeName.size()) {
System.err.println("决策变量指定错误。");
System.exit(2);
}
decatt = n;
}
public void setDec(String targetAttr) {
int n = train_AttributeName.indexOf(targetAttr);
setDec(n);
}
public static void main(String[] args) {
DecisionTree dt=new DecisionTree();
dt.train("files/train.arff", "play");
dt.write_DecisionTree("files/Tree.xml");
}
}
节点类
import java.util.ArrayList;
import java.util.Map;
/**
* 节点类
*/
/**
 * A node of the C4.5 decision tree.
 * Stores the splitting attribute name, the branch value that led here,
 * the class-frequency map of its subset, the majority class, parent/child
 * links, and a node type marker ("leafNode" for leaves).
 */
public class TreeNode {
    private String nodeType;                    // "leafNode" for leaves; null/other for internal nodes
    private String attributeName;               // attribute this node's incoming split was made on
    private String attributeValue;              // value of that attribute on the incoming branch
    private ArrayList<TreeNode> childTreeNode;  // children; null for a leaf
    private TreeNode fatherTreeNode;            // parent; null for the root
    private Map<String, Integer> targetNum;     // class label -> frequency in this node's subset
    private String targetValue;                 // majority class of this node's subset

    public TreeNode() {
    }

    public String getNodeType() {
        return nodeType;
    }

    public void setNodeType(String nodeType) {
        this.nodeType = nodeType;
    }

    public String getAttributeName() {
        return attributeName;
    }

    public void setAttributeName(String attributeName) {
        this.attributeName = attributeName;
    }

    public String getAttributeValue() {
        return attributeValue;
    }

    public void setAttributeValue(String attributeValue) {
        this.attributeValue = attributeValue;
    }

    public ArrayList<TreeNode> getChildTreeNode() {
        return childTreeNode;
    }

    public void setChildTreeNode(ArrayList<TreeNode> childTreeNode) {
        this.childTreeNode = childTreeNode;
    }

    public TreeNode getFatherTreeNode() {
        return fatherTreeNode;
    }

    public void setFatherTreeNode(TreeNode fatherTreeNode) {
        this.fatherTreeNode = fatherTreeNode;
    }

    public Map<String, Integer> getTargetNum() {
        return targetNum;
    }

    public void setTargetNum(Map<String, Integer> targetNum) {
        this.targetNum = targetNum;
    }

    public String getTargetValue() {
        return targetValue;
    }

    public void setTargetValue(String targetValue) {
        this.targetValue = targetValue;
    }
}
信息熵相关类
import java.util.*;
/**
* 信息增益相关类
*/
/**
 * Information-gain-ratio helpers for the C4.5 decision tree.
 * Works on row-index subsets of a shared training matrix so that no data is copied.
 */
public class InfoGain {
    private ArrayList<String[]> trainData; // training rows; each row is one ARFF data line
    private int decatt;                    // column index of the class attribute

    public InfoGain(ArrayList<String[]> trainData, int decatt) {
        this.trainData = trainData;
        this.decatt = decatt;
    }

    /**
     * Entropy (in bits) of a value->count distribution.
     * @param attributeNum value -> frequency map
     * @return entropy; 0 for an empty map
     */
    public double getEntropy(Map<String, Integer> attributeNum) {
        double entropy = 0.0;
        int sum = 0;
        for (int num : attributeNum.values()) {
            sum += num;
            // +Double.MIN_VALUE guards log(0); numerically negligible otherwise
            entropy += (-1) * num * Math.log(num + Double.MIN_VALUE) / Math.log(2);
        }
        if (sum == 0) return 0; // empty subset: define entropy as 0 instead of NaN
        entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    /**
     * Entropy of one attribute column restricted to a row subset.
     */
    public double getEntropy(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, Integer> attributeNum = get_AttributeNum(subset, attributeIndex);
        return getEntropy(attributeNum);
    }

    /**
     * Picks the candidate attribute with the largest information gain ratio.
     * @param subset row indices of the current node
     * @param selatt candidate attribute columns
     * @return column index of the best attribute; falls back to decatt when no
     *         candidate yields a positive gain ratio (caller-visible legacy behavior)
     */
    public int getGainRatioMax(ArrayList<Integer> subset, LinkedList<Integer> selatt) {
        // Entropy of the class distribution before splitting.
        Map<String, Integer> old_TargetNum = get_AttributeNum(subset, decatt);
        double oldEntropy = getEntropy(old_TargetNum);
        double maxGainRatio = 0;
        int maxIndex = decatt;
        for (int attributeIndex : selatt) {
            Map<String, ArrayList<Integer>> attributeSubset = get_AttributeSubset(subset, attributeIndex);
            int sum = 0;
            double newEntropy = 0;
            for (ArrayList<Integer> tempSubset : attributeSubset.values()) {
                int num = tempSubset.size();
                sum += num;
                newEntropy += num * getEntropy(tempSubset, decatt); // weighted class entropy
            }
            newEntropy /= sum;
            double splitInfo = getEntropy(subset, attributeIndex);
            // An attribute with a single value here has zero split info: the gain
            // ratio would be Infinity/NaN, so it cannot be a usable split.
            if (splitInfo <= 0) continue;
            double tempGainRatio = (oldEntropy - newEntropy) / splitInfo; // gain ratio
            if (tempGainRatio > maxGainRatio) {
                maxGainRatio = tempGainRatio;
                maxIndex = attributeIndex;
            }
        }
        return maxIndex;
    }

    /**
     * Whether the subset contains a single class.
     * @param targetNum class -> frequency map
     * @return true when only one class is present
     */
    public boolean isPure(Map<String, Integer> targetNum) {
        return targetNum.size() <= 1;
    }

    /**
     * Builds a value -> frequency map for one attribute over a row subset.
     * @param subset row indices
     * @param attributeIndex attribute column
     * @return value -> count map
     */
    public Map<String, Integer> get_AttributeNum(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, Integer> attributeNum = new HashMap<String, Integer>();
        for (int subsetIndex : subset) {
            String value = trainData.get(subsetIndex)[attributeIndex];
            Integer count = attributeNum.get(value); // boxed so absence (null) is detectable
            attributeNum.put(value, count != null ? count + 1 : 1);
        }
        return attributeNum;
    }

    /**
     * Partitions a row subset by the values of one attribute.
     * @param subset row indices to partition
     * @param attributeIndex attribute column
     * @return value -> list of row indices having that value
     */
    public Map<String, ArrayList<Integer>> get_AttributeSubset(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, ArrayList<Integer>> attributeSubset = new HashMap<String, ArrayList<Integer>>();
        for (int subsetIndex : subset) {
            String value = trainData.get(subsetIndex)[attributeIndex];
            ArrayList<Integer> tempSubset = attributeSubset.get(value);
            if (tempSubset == null) {
                tempSubset = new ArrayList<Integer>();
                attributeSubset.put(value, tempSubset);
            }
            tempSubset.add(subsetIndex);
        }
        return attributeSubset;
    }

    /**
     * Majority class of a class -> frequency map.
     * @param targetNum class -> frequency map
     * @return the class label with the highest count ("" for an empty map)
     */
    public String get_targetValue(Map<String, Integer> targetNum) {
        int maxNum = 0;
        String targetValue = "";
        for (String key : targetNum.keySet()) {
            int tempNum = targetNum.get(key);
            if (tempNum > maxNum) {
                maxNum = tempNum;
                targetValue = key;
            }
        }
        return targetValue;
    }
}
决策树属于比较基本的分类算法,但是在编写代码的过程中,我对于递归的运用和代码实现有了更进一步的认识。
在C4.5中有两块工作比较重要和复杂,其一,自然是生成决策树;其二,便是实现剪枝。
这二者都是通过递归来实现的,并且都经历了自上而下(top-down)与自下而上(bottom-up)两个阶段,只不过前者是在自上而下的过程中完成主要操作,回溯只是用以获得返回的结点;而后者的自上而下只是为了找到各个叶子结点,真正的剪枝工作是在回溯的过程中实现的。
此次的代码中并没有实现对连续特征的处理以及缺失值的处理。
后者根据具体的情况变化较大,而前者根据目前提供的函数应该可以比较方便的实现,也就不再浪费时间了,如果有亲希望保证完整性,可以自行补充。