经典数据挖掘算法(介绍了包括18大数据挖掘在内的多种经典数据挖掘算法)
前言
文章标题的两个概念也许对于许多同学们来说都相对比较陌生,都比较偏向于于理论方面的知识,但是这个算法非常的强大,在很多方面都会存在他的影子。2个概念,1个维特比算法,1个隐马尔可夫模型。你很难想象,输入法的设计也会用到其中的一些知识。
HMM-隐马尔可夫模型
隐马尔可夫模型如果真的要展开来讲,那短短的一篇文章当然无法阐述的清,所以我会以最简单的方式解释。隐马尔可夫模型简称HMM,根据百度百科中的描述,隐马尔可夫模型描述的是一个含有隐含未知参数的马尔可夫模型。模型的本质是从观察的参数中获取隐含的参数信息。一般的在马尔可夫模型中,前后之间的特征会存在部分的依赖影响。示例图如下:
隐马尔可夫模型在语音识别中有广泛的应用。其实在输入法中也有用到。举个例子,假设我输入wszs,分别代表4个字,所以观察特征就是w, s, z, s,那么我想挖掘出他所想表达的信息,也就是我想打出的字是什么,可能是"我是张三",又可能是“晚上再说”,这个就是可能的信息,最可能的信息,就会被放在输入选择越靠前的位置。在这里我们就会用到一个叫维特比的算法了。
Viterbi-维特比算法
维特比算法这个名字的产生是以著名科学家维特比命名的,维特比算法是数字通信中非常常用的一种算法。那么维特比算法到底是做什么的呢。简单的说,他是一种特殊的动态规划算法,也就是DP问题。但是这里可不是单纯的寻找最短路径这些问题,可能他是需要寻找各个条件因素下最大概率的一条路径,假设针对观察特征,会有多个隐含特征值的情况。比如下面这个是多种隐含变量的组合情况,形成了一个密集的篱笆网络。
于是问题就转变成了,如何在这么多的路径中找到最佳路径。如果这是输入法的例子,上面的每一列的值就是某个拼音下对应的可能的字。于是我们就很容易联想到可以用dp的思想去做,每次求得相邻变量之间求得后最佳的值,存在下一列的节点上,而不是组合这么多种情况去算。时间复杂度能降低不少。但是在马尔可夫模型中,你还要考虑一些别的因素。所以总的来说,维特比算法就是一种利用动态规划算法寻找最有可能产生观察序列的隐含信息序列,尤其是在类似于隐马尔可夫模型的应用中。
算法实例
下面给出一个实际例子,来说明一下维特比算法到底怎么用,如果你用过动态规划算法,相信一定能很迅速的理解我想表达的意思。下面这个例子讲的是海藻的观察特征与天气的关系,通过观测海藻的特征状态,退出当天天气的状况。当然当天天气的预测还可能受昨天的天气影响,所以这是一个很棒的隐马尔可夫模型问题。问题的描述是下面这段话:
假设连续观察3天的海藻湿度为(Dry,Damp,Soggy),求这三天最可能的天气情况。天气只有三类(Sunny,Cloudy,Rainy),而且海藻湿度和天气有一定的关系。问题具体描述,链接在此
ok,状态转移概率矩阵和混淆矩阵都已给出,详细代码以及相关数据,请点击链接:https://github.com/linyiqun/DataMiningAlgorithm/tree/master/Others/DataMining_Viterbi
直接给出代码解答,主算法类ViterbiTool.java:
package DataMining_Viterbi;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
* 维特比算法工具类
*
* @author lyq
*
*/
public class ViterbiTool {
// 状态转移概率矩阵文件地址
private String stmFilePath;
// 混淆矩阵文件地址
private String confusionFilePath;
// 初始状态概率
private double[] initStatePro;
// 观察到的状态序列
public String[] observeStates;
// 状态转移矩阵值
private double[][] stMatrix;
// 混淆矩阵值
private double[][] confusionMatrix;
// 各个条件下的潜在特征概率值
private double[][] potentialValues;
// 潜在特征
private ArrayList potentialAttrs;
// 属性值列坐标映射图
private HashMap name2Index;
// 列坐标属性值映射图
private HashMap index2name;
public ViterbiTool(String stmFilePath, String confusionFilePath,
double[] initStatePro, String[] observeStates) {
this.stmFilePath = stmFilePath;
this.confusionFilePath = confusionFilePath;
this.initStatePro = initStatePro;
this.observeStates = observeStates;
initOperation();
}
/**
* 初始化数据操作
*/
private void initOperation() {
double[] temp;
int index;
ArrayList smtDatas;
ArrayList cfDatas;
smtDatas = readDataFile(stmFilePath);
cfDatas = readDataFile(confusionFilePath);
index = 0;
this.stMatrix = new double[smtDatas.size()][];
for (String[] array : smtDatas) {
temp = new double[array.length];
for (int i = 0; i < array.length; i++) {
try {
temp[i] = Double.parseDouble(array[i]);
} catch (NumberFormatException e) {
temp[i] = -1;
}
}
// 将转换后的值赋给数组中
this.stMatrix[index] = temp;
index++;
}
index = 0;
this.confusionMatrix = new double[cfDatas.size()][];
for (String[] array : cfDatas) {
temp = new double[array.length];
for (int i = 0; i < array.length; i++) {
try {
temp[i] = Double.parseDouble(array[i]);
} catch (NumberFormatException e) {
temp[i] = -1;
}
}
// 将转换后的值赋给数组中
this.confusionMatrix[index] = temp;
index++;
}
this.potentialAttrs = new ArrayList<>();
// 添加潜在特征属性
for (String s : smtDatas.get(0)) {
this.potentialAttrs.add(s);
}
// 去除首列无效列
potentialAttrs.remove(0);
this.name2Index = new HashMap<>();
this.index2name = new HashMap<>();
// 添加名称下标映射关系
for (int i = 1; i < smtDatas.get(0).length; i++) {
this.name2Index.put(smtDatas.get(0)[i], i);
// 添加下标到名称的映射
this.index2name.put(i, smtDatas.get(0)[i]);
}
for (int i = 1; i < cfDatas.get(0).length; i++) {
this.name2Index.put(cfDatas.get(0)[i], i);
}
}
/**
* 从文件中读取数据
*/
private ArrayList readDataFile(String filePath) {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
return dataArray;
}
/**
* 根据观察特征计算隐藏的特征概率矩阵
*/
private void calPotencialProMatrix() {
String curObserveState;
// 观察特征和潜在特征的下标
int osIndex;
int psIndex;
double temp;
double maxPro;
// 混淆矩阵概率值,就是相关影响的因素概率
double confusionPro;
this.potentialValues = new double[observeStates.length][potentialAttrs
.size() + 1];
for (int i = 0; i < this.observeStates.length; i++) {
curObserveState = this.observeStates[i];
osIndex = this.name2Index.get(curObserveState);
maxPro = -1;
// 因为是第一个观察特征,没有前面的影响,根据初始状态计算
if (i == 0) {
for (String attr : this.potentialAttrs) {
psIndex = this.name2Index.get(attr);
confusionPro = this.confusionMatrix[psIndex][osIndex];
temp = this.initStatePro[psIndex - 1] * confusionPro;
this.potentialValues[BaseNames.DAY1][psIndex] = temp;
}
} else {
// 后面的潜在特征受前一个特征的影响,以及当前的混淆因素影响
for (String toDayAttr : this.potentialAttrs) {
psIndex = this.name2Index.get(toDayAttr);
confusionPro = this.confusionMatrix[psIndex][osIndex];
int index;
maxPro = -1;
// 通过昨天的概率计算今天此特征的最大概率
for (String yAttr : this.potentialAttrs) {
index = this.name2Index.get(yAttr);
temp = this.potentialValues[i - 1][index]
* this.stMatrix[index][psIndex];
// 计算得到今天此潜在特征的最大概率
if (temp > maxPro) {
maxPro = temp;
}
}
this.potentialValues[i][psIndex] = maxPro * confusionPro;
}
}
}
}
/**
* 根据同时期最大概率值输出潜在特征值
*/
private void outputResultAttr() {
double maxPro;
int maxIndex;
ArrayList psValues;
psValues = new ArrayList<>();
for (int i = 0; i < this.potentialValues.length; i++) {
maxPro = -1;
maxIndex = 0;
for (int j = 0; j < potentialValues[i].length; j++) {
if (this.potentialValues[i][j] > maxPro) {
maxPro = potentialValues[i][j];
maxIndex = j;
}
}
// 取出最大概率下标对应的潜在特征
psValues.add(this.index2name.get(maxIndex));
}
System.out.println("观察特征为:");
for (String s : this.observeStates) {
System.out.print(s + ", ");
}
System.out.println();
System.out.println("潜在特征为:");
for (String s : psValues) {
System.out.print(s + ", ");
}
System.out.println();
}
/**
* 根据观察属性,得到潜在属性信息
*/
public void calHMMObserve() {
calPotencialProMatrix();
outputResultAttr();
}
}
测试结果输出:
观察特征为:
Dry, Damp, Soggy,
潜在特征为:
Sunny, Cloudy, Rainy,
参考文献
百度百科-隐马尔可夫模型
百度百科-维特比
<<数学之美>>第二版-吴军
http://blog.csdn.net/jeiwt/article/details/8076739
作者:Androidlushangderen 发表于2015/8/3 23:09:39 原文链接
阅读:475 评论:0 查看评论
再学贝叶斯网络--TAN树型朴素贝叶斯算法
2015年7月5日 15:18
前言
在前面的时间里已经学习过了NB朴素贝叶斯算法, 又刚刚初步的学习了贝叶斯网络的一些基本概念和常用的计算方法。于是就有了上篇初识贝叶斯网络的文章,由于本人最近一直在研究学习<<贝叶斯网引论>>,也接触到了许多与贝叶斯网络相关的知识,可以说朴素贝叶斯算法这些只是我们所了解贝叶斯知识的很小的一部分。今天我要总结的学习成果就是基于NB算法的,叫做Tree Augmented Naive Bays,中文意思就是树型朴素贝叶斯算法,简单理解就是树增强型NB算法,那么问题来了,他是如何增强的呢,请继续往下正文的描述。
朴素贝叶斯算法
又得要从朴素贝叶斯算法开始讲起了,因为在前言中已经说了,TAN算法是对NB算法的增强,了解过NB算法的,一定知道NB算法在使用的时候是假设属性事件是相互独立的,而决策属性的分类结果是依赖于各个条件属性的情况的,最后选择分类属性中拥有最大后验概率的值为决策属性。比如下面这个模型可以描述一个简单的模型,
上面账号是否真实的依赖属性条件有3个,好友密度,是否使用真实头像,日志密度,假设这3个属性是相互独立的,但是事实上,在这里的头像是否真实和好友密度其实是有关联的,所以更加真实的情况是下面这张情况;
OK,TAN的出现就解决了条件间的部分属性依赖的问题。在上面的例子中我们是根据自己的主观意识判断出头像和好友密度的关系,但是在真实算法中,我们当然希望机器能够自己根据所给数据集帮我们得出这样的关系,令人高兴的事,TAN帮我们做到了这点。
TAN算法
互信息值
互信息值,在百度百科中的解释如下:
互信息值是信息论中一个有用的信息度量。它可以看出是一个信息量里包含另一个随机变量的信息量。
用图线来表示就是下面这样。
中间的I(x;y)就是互信息值,X,Y代表的2种属性。于是下面这个属性就很好理解了,互信息值越大,就代表2个属性关联性越大。互信息值的标准公式如下:
但是在TAN中会有少许的不一样,会有类变量属性的加入,因为属性之间的关联性的前提是要在某一分类属性确定下进行重新计算,不同的类属性值会有不同的属性关联性。下面是TAN中的I(x;Y)计算公式:
现在看不懂不要紧,后面在给出的程序代码中可自行调试。
算法实现过程
TAN的算法过程其实并不简单,在计算完各个属性对的互信息值之后,要进行贝叶斯网络的构建,这个是TAN中最难的部分,这个部分有下面几个阶段。
1、根据各个属性对的互信息值降序排序,依次取出其中的节点对,遵循不产生环路的原则,构造最大权重跨度树,直到选择完n-1条边为止(因为总共n个属性节点,n-1条边即可确定)。按照互信息值从高到低选择的原因就是要保留关联性更高的关联依赖性的边。
2、上述过程构成的是一个无向图,接下来为整个无向图确定边的方向。选择任意一个属性节点作为根节点,由根节点向外的方向为属性节点之间的方向。
3、为每一个属性节点添加父节点,父节点就是分类属性节点,至此贝叶斯网络结构构造完毕。
为了方便大家理解,我在网上截了几张图,下面这张是在5个属性节点中优先选择了互信息值最大的4条作为无向图:
上述带了箭头是因为,我选择的A作为树的根节点,然后方向就全部确定了,因为A直接连着4个属性节点,然后再此基础上添加父节点,就是下面这个样子了。
OK,这样应该就比较好理解了吧,如果还不理解,请仔细分析我写的程序,从代码中去理解这个过程也可以。
分类结果概率的计算
分类结果概率的计算其实非常简单,只要把查询的条件属性传入分类模型中,然后计算不同类属性下的概率值,拥有最大概率值的分类属性值为最终的分类结果。下面是计算公式,就是联合概率分布公式:
代码实现
测试数据集input.txt:
OutLook Temperature Humidity Wind PlayTennis
Sunny Hot High Weak No
Sunny Hot High Strong No
Overcast Hot High Weak Yes
Rainy Mild High Weak Yes
Rainy Cool Normal Weak Yes
Rainy Cool Normal Strong No
Overcast Cool Normal Strong Yes
Sunny Mild High Weak No
Sunny Cool Normal Weak Yes
Rainy Mild Normal Weak Yes
Sunny Mild Normal Strong Yes
Overcast Mild High Strong Yes
Overcast Hot Normal Weak Yes
Rainy Mild High Strong No
节点类Node.java:
package DataMining_TAN;
import java.util.ArrayList;
/**
* 贝叶斯网络节点类
*
* @author lyq
*
*/
public class Node {
//节点唯一id,方便后面节点连接方向的确定
int id;
// 节点的属性名称
String name;
// 该节点所连续的节点
ArrayList connectedNodes;
public Node(int id, String name) {
this.id = id;
this.name = name;
// 初始化变量
this.connectedNodes = new ArrayList<>();
}
/**
* 将自身节点连接到目标给定的节点
*
* @param node
* 下游节点
*/
public void connectNode(Node node) {
//避免连接自身
if(this.id == node.id){
return;
}
// 将节点加入自身节点的节点列表中
this.connectedNodes.add(node);
// 将自身节点加入到目标节点的列表中
node.connectedNodes.add(this);
}
/**
* 判断与目标节点是否相同,主要比较名称是否相同即可
*
* @param node
* 目标结点
* @return
*/
public boolean isEqual(Node node) {
boolean isEqual;
isEqual = false;
// 节点名称相同则视为相等
if (this.id == node.id) {
isEqual = true;
}
return isEqual;
}
}
互信息值类.java:
package DataMining_TAN;
/**
* 属性之间的互信息值,表示属性之间的关联性大小
* @author lyq
*
*/
public class AttrMutualInfo implements Comparable{
//互信息值
Double value;
//关联属性值对
Node[] nodeArray;
public AttrMutualInfo(double value, Node node1, Node node2){
this.value = value;
this.nodeArray = new Node[2];
this.nodeArray[0] = node1;
this.nodeArray[1] = node2;
}
@Override
public int compareTo(AttrMutualInfo o) {
// TODO Auto-generated method stub
return o.value.compareTo(this.value);
}
}
算法主程序类TANTool.java:
package DataMining_TAN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
/**
* TAN树型朴素贝叶斯算法工具类
*
* @author lyq
*
*/
public class TANTool {
// 测试数据集地址
private String filePath;
// 数据集属性总数,其中一个个分类属性
private int attrNum;
// 分类属性名
private String classAttrName;
// 属性列名称行
private String[] attrNames;
// 贝叶斯网络边的方向,数组内的数值为节点id,从i->j
private int[][] edges;
// 属性名到列下标的映射
private HashMap attr2Column;
// 属性,属性对取值集合映射对
private HashMap> attr2Values;
// 贝叶斯网络总节点列表
private ArrayList totalNodes;
// 总的测试数据
private ArrayList totalDatas;
public TANTool(String filePath) {
this.filePath = filePath;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] array;
while ((str = in.readLine()) != null) {
array = str.split(" ");
dataArray.add(array);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
this.totalDatas = dataArray;
this.attrNames = this.totalDatas.get(0);
this.attrNum = this.attrNames.length;
this.classAttrName = this.attrNames[attrNum - 1];
Node node;
this.edges = new int[attrNum][attrNum];
this.totalNodes = new ArrayList<>();
this.attr2Column = new HashMap<>();
this.attr2Values = new HashMap<>();
// 分类属性节点id最小设为0
node = new Node(0, attrNames[attrNum - 1]);
this.totalNodes.add(node);
for (int i = 0; i < attrNames.length; i++) {
if (i < attrNum - 1) {
// 创建贝叶斯网络节点,每个属性一个节点
node = new Node(i + 1, attrNames[i]);
this.totalNodes.add(node);
}
// 添加属性到列下标的映射
this.attr2Column.put(attrNames[i], i);
}
String[] temp;
ArrayList values;
// 进行属性名,属性值对的映射匹配
for (int i = 1; i < this.totalDatas.size(); i++) {
temp = this.totalDatas.get(i);
for (int j = 0; j < temp.length; j++) {
// 判断map中是否包含此属性名
if (this.attr2Values.containsKey(attrNames[j])) {
values = this.attr2Values.get(attrNames[j]);
} else {
values = new ArrayList<>();
}
if (!values.contains(temp[j])) {
// 加入新的属性值
values.add(temp[j]);
}
this.attr2Values.put(attrNames[j], values);
}
}
}
/**
* 根据条件互信息度对构建最大权重跨度树,返回第一个节点为根节点
*
* @param iArray
*/
private Node constructWeightTree(ArrayList iArray) {
Node node1;
Node node2;
Node root;
ArrayList existNodes;
existNodes = new ArrayList<>();
for (Node[] i : iArray) {
node1 = i[0];
node2 = i[1];
// 将2个节点进行连接
node1.connectNode(node2);
// 避免出现环路现象
addIfNotExist(node1, existNodes);
addIfNotExist(node2, existNodes);
if (existNodes.size() == attrNum - 1) {
break;
}
}
// 返回第一个作为根节点
root = existNodes.get(0);
return root;
}
/**
* 为树型结构确定边的方向,方向为属性根节点方向指向其他属性节点方向
*
* @param root
* 当前遍历到的节点
*/
private void confirmGraphDirection(Node currentNode) {
int i;
int j;
ArrayList connectedNodes;
connectedNodes = currentNode.connectedNodes;
i = currentNode.id;
for (Node n : connectedNodes) {
j = n.id;
// 判断连接此2节点的方向是否被确定
if (edges[i][j] == 0 && edges[j][i] == 0) {
// 如果没有确定,则制定方向为i->j
edges[i][j] = 1;
// 递归继续搜索
confirmGraphDirection(n);
}
}
}
/**
* 为属性节点添加分类属性节点为父节点
*
* @param parentNode
* 父节点
* @param nodeList
* 子节点列表
*/
private void addParentNode() {
// 分类属性节点
Node parentNode;
parentNode = null;
for (Node n : this.totalNodes) {
if (n.id == 0) {
parentNode = n;
break;
}
}
for (Node child : this.totalNodes) {
parentNode.connectNode(child);
if (child.id != 0) {
// 确定连接方向
this.edges[0][child.id] = 1;
}
}
}
/**
* 在节点集合中添加节点
*
* @param node
* 待添加节点
* @param existNodes
* 已存在的节点列表
* @return
*/
public boolean addIfNotExist(Node node, ArrayList existNodes) {
boolean canAdd;
canAdd = true;
for (Node n : existNodes) {
// 如果节点列表中已经含有节点,则算添加失败
if (n.isEqual(node)) {
canAdd = false;
break;
}
}
if (canAdd) {
existNodes.add(node);
}
return canAdd;
}
/**
* 计算节点条件概率
*
* @param node
* 关于node的后验概率
* @param queryParam
* 查询的属性参数
* @return
*/
private double calConditionPro(Node node, HashMap queryParam) {
int id;
double pro;
String value;
String[] attrValue;
ArrayList priorAttrInfos;
ArrayList backAttrInfos;
ArrayList parentNodes;
pro = 1;
id = node.id;
parentNodes = new ArrayList<>();
priorAttrInfos = new ArrayList<>();
backAttrInfos = new ArrayList<>();
for (int i = 0; i < this.edges.length; i++) {
// 寻找父节点id
if (this.edges[i][id] == 1) {
for (Node temp : this.totalNodes) {
// 寻找目标节点id
if (temp.id == i) {
parentNodes.add(temp);
break;
}
}
}
}
// 获取先验属性的属性值,首先添加先验属性
value = queryParam.get(node.name);
attrValue = new String[2];
attrValue[0] = node.name;
attrValue[1] = value;
priorAttrInfos.add(attrValue);
// 逐一添加后验属性
for (Node p : parentNodes) {
value = queryParam.get(p.name);
attrValue = new String[2];
attrValue[0] = p.name;
attrValue[1] = value;
backAttrInfos.add(attrValue);
}
pro = queryConditionPro(priorAttrInfos, backAttrInfos);
return pro;
}
/**
* 查询条件概率
*
* @param attrValues
* 条件属性值
* @return
*/
private double queryConditionPro(ArrayList priorValues,
ArrayList backValues) {
// 判断是否满足先验属性值条件
boolean hasPrior;
// 判断是否满足后验属性值条件
boolean hasBack;
int attrIndex;
double backPro;
double totalPro;
double pro;
String[] tempData;
pro = 0;
totalPro = 0;
backPro = 0;
// 跳过第一行的属性名称行
for (int i = 1; i < this.totalDatas.size(); i++) {
tempData = this.totalDatas.get(i);
hasPrior = true;
hasBack = true;
// 判断是否满足先验条件
for (String[] array : priorValues) {
attrIndex = this.attr2Column.get(array[0]);
// 判断值是否满足条件
if (!tempData[attrIndex].equals(array[1])) {
hasPrior = false;
break;
}
}
// 判断是否满足后验条件
for (String[] array : backValues) {
attrIndex = this.attr2Column.get(array[0]);
// 判断值是否满足条件
if (!tempData[attrIndex].equals(array[1])) {
hasBack = false;
break;
}
}
// 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数
if (hasBack) {
backPro++;
if (hasPrior) {
totalPro++;
}
} else if (hasPrior && backValues.size() == 0) {
// 如果只有先验概率则为纯概率的计算
totalPro++;
backPro = 1.0;
}
}
if (backPro == 0) {
pro = 0;
} else {
// 计算总的概率=都发生概率/只发生后验条件的时间概率
pro = totalPro / backPro;
}
return pro;
}
/**
* 输入查询条件参数,计算发生概率
*
* @param queryParam
* 条件参数
* @return
*/
public double calHappenedPro(String queryParam) {
double result;
double temp;
// 分类属性值
String classAttrValue;
String[] array;
String[] array2;
HashMap params;
result = 1;
params = new HashMap<>();
// 进行查询字符的参数分解
array = queryParam.split(",");
for (String s : array) {
array2 = s.split("=");
params.put(array2[0], array2[1]);
}
classAttrValue = params.get(classAttrName);
// 构建贝叶斯网络结构
constructBayesNetWork(classAttrValue);
for (Node n : this.totalNodes) {
temp = calConditionPro(n, params);
// 为了避免出现条件概率为0的现象,进行轻微矫正
if (temp == 0) {
temp = 0.001;
}
// 按照联合概率公式,进行乘积运算
result *= temp;
}
return result;
}
/**
* 构建树型贝叶斯网络结构
*
* @param value
* 类别量值
*/
private void constructBayesNetWork(String value) {
Node rootNode;
ArrayList mInfoArray;
// 互信息度对
ArrayList iArray;
iArray = null;
rootNode = null;
// 在每次重新构建贝叶斯网络结构的时候,清空原有的连接结构
for (Node n : this.totalNodes) {
n.connectedNodes.clear();
}
this.edges = new int[attrNum][attrNum];
// 从互信息对象中取出属性值对
iArray = new ArrayList<>();
mInfoArray = calAttrMutualInfoArray(value);
for (AttrMutualInfo v : mInfoArray) {
iArray.add(v.nodeArray);
}
// 构建最大权重跨度树
rootNode = constructWeightTree(iArray);
// 为无向图确定边的方向
confirmGraphDirection(rootNode);
// 为每个属性节点添加分类属性父节点
addParentNode();
}
/**
* 给定分类变量值,计算属性之间的互信息值
*
* @param value
* 分类变量值
* @return
*/
private ArrayList calAttrMutualInfoArray(String value) {
double iValue;
Node node1;
Node node2;
AttrMutualInfo mInfo;
ArrayList mInfoArray;
mInfoArray = new ArrayList<>();
for (int i = 0; i < this.totalNodes.size() - 1; i++) {
node1 = this.totalNodes.get(i);
// 跳过分类属性节点
if (node1.id == 0) {
continue;
}
for (int j = i + 1; j < this.totalNodes.size(); j++) {
node2 = this.totalNodes.get(j);
// 跳过分类属性节点
if (node2.id == 0) {
continue;
}
// 计算2个属性节点之间的互信息值
iValue = calMutualInfoValue(node1, node2, value);
mInfo = new AttrMutualInfo(iValue, node1, node2);
mInfoArray.add(mInfo);
}
}
// 将结果进行降序排列,让互信息值高的优先用于构建树
Collections.sort(mInfoArray);
return mInfoArray;
}
/**
* 计算2个属性节点的互信息值
*
* @param node1
* 节点1
* @param node2
* 节点2
* @param vlaue
* 分类变量值
*/
private double calMutualInfoValue(Node node1, Node node2, String value) {
double iValue;
double temp;
// 三种不同条件的后验概率
double pXiXj;
double pXi;
double pXj;
String[] array1;
String[] array2;
ArrayList attrValues1;
ArrayList attrValues2;
ArrayList priorValues;
// 后验概率,在这里就是类变量值
ArrayList backValues;
array1 = new String[2];
array2 = new String[2];
priorValues = new ArrayList<>();
backValues = new ArrayList<>();
iValue = 0;
array1[0] = classAttrName;
array1[1] = value;
// 后验属性都是类属性
backValues.add(array1);
// 获取节点属性的属性值集合
attrValues1 = this.attr2Values.get(node1.name);
attrValues2 = this.attr2Values.get(node2.name);
for (String v1 : attrValues1) {
for (String v2 : attrValues2) {
priorValues.clear();
array1 = new String[2];
array1[0] = node1.name;
array1[1] = v1;
priorValues.add(array1);
array2 = new String[2];
array2[0] = node2.name;
array2[1] = v2;
priorValues.add(array2);
// 计算3种条件下的概率
pXiXj = queryConditionPro(priorValues, backValues);
priorValues.clear();
priorValues.add(array1);
pXi = queryConditionPro(priorValues, backValues);
priorValues.clear();
priorValues.add(array2);
pXj = queryConditionPro(priorValues, backValues);
// 如果出现其中一个计数概率为0,则直接赋值为0处理
if (pXiXj == 0 || pXi == 0 || pXj == 0) {
temp = 0;
} else {
// 利用公式计算针对此属性值对组合的概率
temp = pXiXj * Math.log(pXiXj / (pXi * pXj)) / Math.log(2);
}
// 进行和属性值对组合的累加即为整个属性的互信息值
iValue += temp;
}
}
return iValue;
}
}
场景测试类client.java:
package DataMining_TAN;
/**
* TAN树型朴素贝叶斯算法
*
* @author lyq
*
*/
public class Client {
public static void main(String[] args) {
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
// 条件查询语句
String queryStr;
// 分类结果概率1
double classResult1;
// 分类结果概率2
double classResult2;
TANTool tool = new TANTool(filePath);
queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=No";
classResult1 = tool.calHappenedPro(queryStr);
queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=Yes";
classResult2 = tool.calHappenedPro(queryStr);
System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=No",
classResult1));
System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=Yes",
classResult2));
if (classResult1 > classResult2) {
System.out.println("分类类别为PlayTennis=No");
} else {
System.out.println("分类类别为PlayTennis=Yes");
}
}
}
结果输出:
类别为PlayTennis=No所求得的概率为0.09523809523809525
类别为PlayTennis=Yes所求得的概率为3.571428571428571E-5
分类类别为PlayTennis=No
参考文献
百度百科
贝叶斯网络分类器与应用,作者:余民杰
用于数据挖掘的TAN分类器的研究和应用,作者:孙笑徽等4人
更多数据挖掘算法
https://github.com/linyiqun/DataMiningAlgorithm
作者:Androidlushangderen 发表于2015/7/5 15:18:09 原文链接
阅读:638 评论:0 查看评论
初识贝叶斯网络
2015年6月29日 16:38
前言
一看到贝叶斯网络,马上让人联想到的是5个字,朴素贝叶斯,在所难免,NaiveByes的知名度确实会被贝叶斯网络算法更高一点。其实不管是朴素贝叶斯算法,还是今天我打算讲述的贝叶斯网络算法也罢,归根结底来说都是贝叶斯系列分类算法,他的核心思想就是基于概率学的知识进行分类判断,至于分类得到底准不准,大家尽可以自己用数据集去测试测试。OK,下面进入正题--贝叶斯网络算法。
朴素贝叶斯
一般我在介绍某种算法之前,都事先会学习一下相关的算法,以便于新算法的学习,而与贝叶斯网络算法相关性比较大的在我看来就是朴素贝叶斯算法,而且前段时间也恰好学习过,简单的来说,朴素贝叶斯算法的假设条件是各个事件相互独立,然后利用贝叶斯定理,做概率的计算,于是这个算法的核心就是就是这个贝叶斯定理的运用了喽,不错,贝叶斯定理的确很有用,他是基于条件概率的先验概率和后验概率的转换公式,这么说有点抽象,下面是公式的表达式:
大学里概率学的课本上都有介绍过的,这个公式的好处在于对于一些比较难直接得出的概率通过转换后的概率计算可得,一般是把决策属性值放在先验属性中,当做目标值,然后通过决策属性值的后验概率计算所得。具体请查看我的朴素贝叶斯算法介绍。
贝叶斯网络
下面这个部分就是文章的主题了,贝叶斯网络,里面有2个字非常关键,就是网络,网络代表的潜在意思有2点,第一是有结构的,第二存在关联,我们可以马上联想到DAG有向无环图。不错,存在关联的这个特点就是与朴素贝叶斯算法最大的一个不同点,因为朴素贝叶斯算法在计算概率值上是假设各个事务属性是相互独立的,但是理性的思考一下,其实这个很难做到,任何事务,如果你仔细去想想,其实都还是有点联系的。比如这里有个例子:
在SNS社区中检验账号的真实性
如果用朴素贝叶斯来做的话,就会是这样的假设:
i、真实账号比非真实账号平均具有更大的日志密度、各大的好友密度以及更多的使用真实头像。
ii、日志密度、好友密度和是否使用真实头像在账号真实性给定的条件下是独立的。
但是其实往深入一想,使用真实的头像其实是会提高人家添加你为好友的概率的,所以在这个条件的独立其实是有问题的,所以在贝叶斯网络中是允许关联的存在的,假设就变为如下:
i、真实账号比非真实账号平均具有更大的日志密度、各大的好友密度以及更多的使用真实头像。
ii、日志密度与好友密度、日志密度与是否使用真实头像在账号真实性给定的条件下是独立的。
iii、使用真实头像的用户比使用非真实头像的用户平均有更大的好友密度。
在贝叶斯网络中,会用一张DAG来表示,每个节点代表某个属性事件,每条边代表其中的条件概率,如下:
贝叶斯网络概率的计算
贝叶斯网络概率的计算很简单,是从联合概率分布公式中变换所得,下面是联合概率分布公式:
而在贝叶斯网络中,由于存在前述的关系存在,该公式就被简化为了如下:
其中Parent(xi),表示的是xi的前驱结点,如果还不理解,可以对照我后面的代码,自行调试分析。
代码实现
需要输入2部分的数据,依赖关系,用于构建贝叶斯网络图,第二个是测试数据集,算法总代码地址:
https://github.com/linyiqun/DataMiningAlgorithm/tree/master/Others/DataMining_BayesNetwork
依赖关系数据如下:
B A
E A
A M
A J
测试数据集:
B E A M J P
y y y y y 0.00012
y y y y n 0.000051
y y y n y 0.000013
y y y n n 0.0000057
y y n y y 0.000000005
y y n y n 0.00000049
y y n n y 0.000000095
y y n n n 0.0000094
y n y y y 0.0058
y n y y n 0.0025
y n y n y 0.00065
y n y n n 0.00028
y n n y y 0.00000029
y n n y n 0.000029
y n n n y 0.0000056
y n n n n 0.00055
n y y y y 0.0036
n y y y n 0.0016
n y y n y 0.0004
n y y n n 0.00017
n y n y y 0.000007
n y n y n 0.00069
n y n n y 0.00013
n y n n n 0.013
n n y y y 0.00061
n n y y n 0.00026
n n y n y 0.000068
n n y n n 0.000029
n n n y y 0.00048
n n n y n 0.048
n n n n y 0.0092
n n n n n 0.91
节点类Node.java:
package DataMining_BayesNetwork;
import java.util.ArrayList;
/**
* 贝叶斯网络节点类
*
* @author lyq
*
*/
public class Node {
// 节点的属性名称
String name;
// 节点的父亲节点,也就是上游节点,可能多个
ArrayList parentNodes;
// 节点的子节点,也就是下游节点,可能多个
ArrayList childNodes;
public Node(String name) {
this.name = name;
// 初始化变量
this.parentNodes = new ArrayList<>();
this.childNodes = new ArrayList<>();
}
/**
* 将自身节点连接到目标给定的节点
*
* @param node
* 下游节点
*/
public void connectNode(Node node) {
// 将下游节点加入自身节点的孩子节点中
this.childNodes.add(node);
// 将自身节点加入到下游节点的父节点中
node.parentNodes.add(this);
}
/**
* 判断与目标节点是否相同,主要比较名称是否相同即可
*
* @param node
* 目标结点
* @return
*/
public boolean isEqual(Node node) {
boolean isEqual;
isEqual = false;
// 节点名称相同则视为相等
if (this.name.equals(node.name)) {
isEqual = true;
}
return isEqual;
}
}
算法类BayesNetworkTool.java:
package DataMining_BayesNetwork;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
/**
* 贝叶斯网络算法工具类
*
* @author lyq
*
*/
public class BayesNetWorkTool {
// 联合概率分布数据文件地址
private String dataFilePath;
// 事件关联数据文件地址
private String attachFilePath;
// 属性列列数
private int columns;
// 概率分布数据
private String[][] totalData;
// 关联数据对
private ArrayList attachData;
// 节点存放列表
private ArrayList nodes;
// 属性名与列数之间的对应关系
private HashMap attr2Column;
public BayesNetWorkTool(String dataFilePath, String attachFilePath) {
this.dataFilePath = dataFilePath;
this.attachFilePath = attachFilePath;
initDatas();
}
/**
* 初始化关联数据和概率分布数据
*/
private void initDatas() {
String[] columnValues;
String[] array;
ArrayList datas;
ArrayList adatas;
// 从文件中读取数据
datas = readDataFile(dataFilePath);
adatas = readDataFile(attachFilePath);
columnValues = datas.get(0).split(" ");
// 属性割名称代表事件B(盗窃),E(地震),A(警铃响).M(接到M的电话),J同M的意思,
// 属性值都是y,n代表yes发生和no不发生
this.attr2Column = new HashMap<>();
for (int i = 0; i < columnValues.length; i++) {
// 从数据中取出属性名称行,列数值存入图中
this.attr2Column.put(columnValues[i], i);
}
this.columns = columnValues.length;
this.totalData = new String[datas.size()][columns];
for (int i = 0; i < datas.size(); i++) {
this.totalData[i] = datas.get(i).split(" ");
}
this.attachData = new ArrayList<>();
// 解析关联数据对
for (String str : adatas) {
array = str.split(" ");
this.attachData.add(array);
}
// 构造贝叶斯网络结构图
constructDAG();
}
/**
* 从文件中读取数据
*/
private ArrayList readDataFile(String filePath) {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
while ((str = in.readLine()) != null) {
dataArray.add(str);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
return dataArray;
}
/**
* 根据关联数据构造贝叶斯网络无环有向图
*/
private void constructDAG() {
// 节点存在标识
boolean srcExist;
boolean desExist;
String name1;
String name2;
Node srcNode;
Node desNode;
this.nodes = new ArrayList<>();
for (String[] array : this.attachData) {
srcExist = false;
desExist = false;
name1 = array[0];
name2 = array[1];
// 新建节点
srcNode = new Node(name1);
desNode = new Node(name2);
for (Node temp : this.nodes) {
// 如果找到相同节点,则取出
if (srcNode.isEqual(temp)) {
srcExist = true;
srcNode = temp;
} else if (desNode.isEqual(temp)) {
desExist = true;
desNode = temp;
}
// 如果2个节点都已找到,则跳出循环
if (srcExist && desExist) {
break;
}
}
// 将2个节点进行连接
srcNode.connectNode(desNode);
// 根据标识判断是否需要加入列表容器中
if (!srcExist) {
this.nodes.add(srcNode);
}
if (!desExist) {
this.nodes.add(desNode);
}
}
}
/**
* 查询条件概率
*
* @param attrValues
* 条件属性值
* @return
*/
private double queryConditionPro(ArrayList attrValues) {
// 判断是否满足先验属性值条件
boolean hasPrior;
// 判断是否满足后验属性值条件
boolean hasBack;
int priorIndex;
int attrIndex;
double backPro;
double totalPro;
double pro;
double currentPro;
// 先验属性
String[] priorValue;
String[] tempData;
pro = 0;
totalPro = 0;
backPro = 0;
attrValues.get(0);
priorValue = attrValues.get(0);
// 得到后验概率
attrValues.remove(0);
// 取出先验属性的列数
priorIndex = this.attr2Column.get(priorValue[0]);
// 跳过第一行的属性名称行
for (int i = 1; i < this.totalData.length; i++) {
tempData = this.totalData[i];
hasPrior = false;
hasBack = true;
// 当前行的概率
currentPro = Double.parseDouble(tempData[this.columns - 1]);
// 判断是否满足先验条件
if (tempData[priorIndex].equals(priorValue[1])) {
hasPrior = true;
}
for (String[] array : attrValues) {
attrIndex = this.attr2Column.get(array[0]);
// 判断值是否满足条件
if (!tempData[attrIndex].equals(array[1])) {
hasBack = false;
break;
}
}
// 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数
if (hasBack) {
backPro += currentPro;
if (hasPrior) {
totalPro += currentPro;
}
} else if (hasPrior && attrValues.size() == 0) {
// 如果只有先验概率则为纯概率的计算
totalPro += currentPro;
backPro = 1.0;
}
}
// 计算总的概率=都发生概率/只发生后验条件的时间概率
pro = totalPro / backPro;
return pro;
}
/**
* 根据贝叶斯网络计算概率
*
* @param queryStr
* 查询条件串
* @return
*/
public double calProByNetWork(String queryStr) {
double temp;
double pro;
String[] array;
// 先验条件值
String[] preValue;
// 后验条件值
String[] backValue;
// 所有先验条件和后验条件值的属性值的汇总
ArrayList attrValues;
// 判断是否满足网络结构
if (!satisfiedNewWork(queryStr)) {
return -1;
}
pro = 1;
// 首先做查询条件的分解
array = queryStr.split(",");
// 概率的初值等于第一个事件发生的随机概率
attrValues = new ArrayList<>();
attrValues.add(array[0].split("="));
pro = queryConditionPro(attrValues);
for (int i = 0; i < array.length - 1; i++) {
attrValues.clear();
// 下标小的在前面的属于后验属性
backValue = array[i].split("=");
preValue = array[i + 1].split("=");
attrValues.add(preValue);
attrValues.add(backValue);
// 算出此种情况的概率值
temp = queryConditionPro(attrValues);
// 进行积的相乘
pro *= temp;
}
return pro;
}
/**
* 验证事件的查询因果关系是否满足贝叶斯网络
*
* @param queryStr
* 查询字符串
* @return
*/
private boolean satisfiedNewWork(String queryStr) {
String attrName;
String[] array;
boolean isExist;
boolean isSatisfied;
// 当前节点
Node currentNode;
// 候选节点列表
ArrayList nodeList;
isSatisfied = true;
currentNode = null;
// 做查询字符串的分解
array = queryStr.split(",");
nodeList = this.nodes;
for (String s : array) {
// 开始时默认属性对应的节点不存在
isExist = false;
// 得到属性事件名
attrName = s.split("=")[0];
for (Node n : nodeList) {
if (n.name.equals(attrName)) {
isExist = true;
currentNode = n;
// 下一轮的候选节点为当前节点的孩子节点
nodeList = currentNode.childNodes;
break;
}
}
// 如果存在未找到的节点,则说明不满足依赖结构跳出循环
if (!isExist) {
isSatisfied = false;
break;
}
}
return isSatisfied;
}
}
场景测试类Client.java:
package DataMining_BayesNetwork;
import java.text.MessageFormat;
/**
* 贝叶斯网络场景测试类
*
* @author lyq
*
*/
public class Client {
public static void main(String[] args) {
String dataFilePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
String attachFilePath = "C:\\Users\\lyq\\Desktop\\icon\\attach.txt";
// 查询串语句
String queryStr;
// 结果概率
double result;
// 查询语句的描述的事件是地震发生了,导致响铃响了,导致接到Mary的电话
queryStr = "E=y,A=y,M=y";
BayesNetWorkTool tool = new BayesNetWorkTool(dataFilePath,
attachFilePath);
result = tool.calProByNetWork(queryStr);
if (result == -1) {
System.out.println("所描述的事件不满足贝叶斯网络的结构,无法求其概率");
} else {
System.out.println(String.format("事件%s发生的概率为%s", queryStr, result));
}
}
}
输出结果:
事件E=y,A=y,M=y发生的概率为0.005373075715453122
参考文献
百度百科
http://www.cnblogs.com/leoo2sk/archive/2010/09/18/bayes-network.html
更多数据挖掘算法
https://github.com/linyiqun/DataMiningAlgorithm
作者:Androidlushangderen 发表于2015/6/29 16:38:45 原文链接
阅读:626 评论:0 查看评论
ACO蚁群算法解决TSP旅行商问题
2015年4月30日 15:31
前言
蚁群算法也是一种利用了大自然规律的启发式算法,与之前学习过的GA遗传算法类似,遗传算法是用了生物进行理论,把更具适应性的基因传给下一代,最后就能得到一个最优解,常常用来寻找问题的最优解。当然,本篇文章不会主讲GA算法的,想要了解的同学可以查看,我的遗传算法学习和遗传算法在走迷宫中的应用。话题重新回到蚁群算法,蚁群算法是一个利用了蚂蚁寻找食物的原理。不知道小时候有没有发现,当一个蚂蚁发现了地上的食物,然后非常迅速的,就有其他的蚂蚁聚拢过来,最后把食物抬回家,这里面其实有着非常多的道理的,在ACO中就用到了这个机理用于解决实际生活中的一些问题。
蚂蚁找食物
首先我们要具体说说一个有意思的事情,就是蚂蚁找食物的问题,理解了这个原理之后,对于理解ACO算法就非常容易了。蚂蚁作为那么小的动物,在地上漫无目的的寻找食物,起初都是没有目标的,他从蚂蚁洞中走出,随机的爬向各个方向,在这期间他会向外界播撒一种化学物质,姑且就叫做信息素,所以这里就可以得到的一个前提,越多蚂蚁走过的路径,信息素浓度就会越高,那么某条路径信息素浓度高了,自然就会有越多的蚂蚁感觉到了,就会聚集过来了。所以当众多蚂蚁中的一个找到食物之后,他就会在走过的路径中放出信息素浓度,因此就会有很多的蚂蚁赶来了。类似下面的场景:
至于蚂蚁是如何感知这个信息素,这个就得问生物学家了,我也没做过研究。
算法介绍
OK,有了上面这个自然生活中的生物场景之后,我们再来切入文章主题来学习一下蚁群算法,百度百科中对应蚁群算法是这么介绍的:蚁群算法是一种在图中寻找优化路径的机率型算法。他的灵感就是来自于蚂蚁发现食物的行为。蚁群算法是一种新的模拟进化优化的算法,与遗传算法有很多相似的地方。蚁群算法在比较早的时候成功解决了TSP旅行商的问题(在后面的例子中也会以这个例子)。要用算法去模拟蚂蚁的这种行为,关键在于信息素的在算法中的设计,以及路径中信息素浓度越大的路径,将会有更高的概率被蚂蚁所选择到。
算法原理
要想实现上面的几个模拟行为,需要借助几个公式,当然公式不是我自己定义的,主要有3个,如下图:
上图中所出现的alpha,beita,p等数字都是控制因子,所以可不必理会,Tij(n)的意思是在时间为n的时候,从城市i到城市j的路径的信息素浓度。类似于nij的字母是城市i到城市j距离的倒数。就是下面这个公式。
所以所有的公式都是为第一个公式服务的,第一个公式的意思是指第k只蚂蚁选择从城市i到城市j的概率,可以见得,这个受距离和信息素浓度的双重影响,距离越远,去此城市的概率自然也低,所以nij会等于距离的倒数,而且在算信息素浓度的时候,也考虑到了信息素浓度衰减的问题,所以会在上次的浓度值上乘以一个衰减因子P。另外还要加上本轮搜索增加的信息素浓度(假如有蚂蚁经过此路径的话),所以这几个公式的整体设计思想还是非常棒的。
算法的代码实现
由于本身我这里没有什么真实的测试数据,就随便自己构造了一个简单的数据,输入如下,分为城市名称和城市之间的距离,用#符号做区分标识,大家应该可以看得懂吧
# CityName
1
2
3
4
# Distance
1 2 1
1 3 1.4
1 4 1
2 3 1
2 4 1
3 4 1
蚂蚁类Ant.java:
package DataMining_ACO;
import java.util.ArrayList;
/**
* 蚂蚁类,进行路径搜索的载体
*
* @author lyq
*
*/
public class Ant implements Comparable {
// 蚂蚁当前所在城市
String currentPos;
// 蚂蚁遍历完回到原点所用的总距离
Double sumDistance;
// 城市间的信息素浓度矩阵,随着时间的增多而减少
double[][] pheromoneMatrix;
// 蚂蚁已经走过的城市集合
ArrayList visitedCitys;
// 还未走过的城市集合
ArrayList nonVisitedCitys;
// 蚂蚁当前走过的路径
ArrayList currentPath;
public Ant(double[][] pheromoneMatrix, ArrayList nonVisitedCitys) {
this.pheromoneMatrix = pheromoneMatrix;
this.nonVisitedCitys = nonVisitedCitys;
this.visitedCitys = new ArrayList<>();
this.currentPath = new ArrayList<>();
}
/**
* 计算路径的总成本(距离)
*
* @return
*/
public double calSumDistance() {
sumDistance = 0.0;
String lastCity;
String currentCity;
for (int i = 0; i < currentPath.size() - 1; i++) {
lastCity = currentPath.get(i);
currentCity = currentPath.get(i + 1);
// 通过距离矩阵进行计算
sumDistance += ACOTool.disMatrix[Integer.parseInt(lastCity)][Integer
.parseInt(currentCity)];
}
return sumDistance;
}
/**
* 蚂蚁选择前往下一个城市
*
* @param city
* 所选的城市
*/
public void goToNextCity(String city) {
this.currentPath.add(city);
this.currentPos = city;
this.nonVisitedCitys.remove(city);
this.visitedCitys.add(city);
}
/**
* 判断蚂蚁是否已经又重新回到起点
*
* @return
*/
public boolean isBack() {
boolean isBack = false;
String startPos;
String endPos;
if (currentPath.size() == 0) {
return isBack;
}
startPos = currentPath.get(0);
endPos = currentPath.get(currentPath.size() - 1);
if (currentPath.size() > 1 && startPos.equals(endPos)) {
isBack = true;
}
return isBack;
}
/**
* 判断蚂蚁在本次的走过的路径中是否包含从城市i到城市j
*
* @param cityI
* 城市I
* @param cityJ
* 城市J
* @return
*/
public boolean pathContained(String cityI, String cityJ) {
String lastCity;
String currentCity;
boolean isContained = false;
for (int i = 0; i < currentPath.size() - 1; i++) {
lastCity = currentPath.get(i);
currentCity = currentPath.get(i + 1);
// 如果某一段路径的始末位置一致,则认为有经过此城市
if ((lastCity.equals(cityI) && currentCity.equals(cityJ))
|| (lastCity.equals(cityJ) && currentCity.equals(cityI))) {
isContained = true;
break;
}
}
return isContained;
}
@Override
public int compareTo(Ant o) {
// TODO Auto-generated method stub
return this.sumDistance.compareTo(o.sumDistance);
}
}
蚁群算法工具类ACOTool.java:
package DataMining_ACO;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* 蚁群算法工具类
*
* @author lyq
*
*/
public class ACOTool {
// 输入数据类型
public static final int INPUT_CITY_NAME = 1;
public static final int INPUT_CITY_DIS = 2;
// 城市间距离邻接矩阵
public static double[][] disMatrix;
// 当前时间
public static int currentTime;
// 测试数据地址
private String filePath;
// 蚂蚁数量
private int antNum;
// 控制参数
private double alpha;
private double beita;
private double p;
private double Q;
// 随机数产生器
private Random random;
// 城市名称集合,这里为了方便,将城市用数字表示
private ArrayList totalCitys;
// 所有的蚂蚁集合
private ArrayList totalAnts;
// 城市间的信息素浓度矩阵,随着时间的增多而减少
private double[][] pheromoneMatrix;
// 目标的最短路径,顺序为从集合的前部往后挪动
private ArrayList bestPath;
// 信息素矩阵存储图,key采用的格式(i,j,t)->value
private Map pheromoneTimeMap;
public ACOTool(String filePath, int antNum, double alpha, double beita,
double p, double Q) {
this.filePath = filePath;
this.antNum = antNum;
this.alpha = alpha;
this.beita = beita;
this.p = p;
this.Q = Q;
this.currentTime = 0;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
int flag = -1;
int src = 0;
int des = 0;
int size = 0;
// 进行城市名称种数的统计
this.totalCitys = new ArrayList<>();
for (String[] array : dataArray) {
if (array[0].equals("#") && totalCitys.size() == 0) {
flag = INPUT_CITY_NAME;
continue;
} else if (array[0].equals("#") && totalCitys.size() > 0) {
size = totalCitys.size();
// 初始化距离矩阵
this.disMatrix = new double[size + 1][size + 1];
this.pheromoneMatrix = new double[size + 1][size + 1];
// 初始值-1代表此对应位置无值
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
this.disMatrix[i][j] = -1;
this.pheromoneMatrix[i][j] = -1;
}
}
flag = INPUT_CITY_DIS;
continue;
}
if (flag == INPUT_CITY_NAME) {
this.totalCitys.add(array[0]);
} else {
src = Integer.parseInt(array[0]);
des = Integer.parseInt(array[1]);
this.disMatrix[src][des] = Double.parseDouble(array[2]);
this.disMatrix[des][src] = Double.parseDouble(array[2]);
}
}
}
/**
* 计算从蚂蚁城市i到j的概率
*
* @param cityI
* 城市I
* @param cityJ
* 城市J
* @param currentTime
* 当前时间
* @return
*/
private double calIToJProbably(String cityI, String cityJ, int currentTime) {
double pro = 0;
double n = 0;
double pheromone;
int i;
int j;
i = Integer.parseInt(cityI);
j = Integer.parseInt(cityJ);
pheromone = getPheromone(currentTime, cityI, cityJ);
n = 1.0 / disMatrix[i][j];
if (pheromone == 0) {
pheromone = 1;
}
pro = Math.pow(n, alpha) * Math.pow(pheromone, beita);
return pro;
}
/**
* 计算综合概率蚂蚁从I城市走到J城市的概率
*
* @return
*/
public String selectAntNextCity(Ant ant, int currentTime) {
double randomNum;
double tempPro;
// 总概率指数
double proTotal;
String nextCity = null;
ArrayList allowedCitys;
// 各城市概率集
double[] proArray;
// 如果是刚刚开始的时候,没有路过任何城市,则随机返回一个城市
if (ant.currentPath.size() == 0) {
nextCity = String.valueOf(random.nextInt(totalCitys.size()) + 1);
return nextCity;
} else if (ant.nonVisitedCitys.isEmpty()) {
// 如果全部遍历完毕,则再次回到起点
nextCity = ant.currentPath.get(0);
return nextCity;
}
proTotal = 0;
allowedCitys = ant.nonVisitedCitys;
proArray = new double[allowedCitys.size()];
for (int i = 0; i < allowedCitys.size(); i++) {
nextCity = allowedCitys.get(i);
proArray[i] = calIToJProbably(ant.currentPos, nextCity, currentTime);
proTotal += proArray[i];
}
for (int i = 0; i < allowedCitys.size(); i++) {
// 归一化处理
proArray[i] /= proTotal;
}
// 用随机数选择下一个城市
randomNum = random.nextInt(100) + 1;
randomNum = randomNum / 100;
// 因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断
if (randomNum == 1) {
randomNum = randomNum - 0.01;
}
tempPro = 0;
// 确定区间
for (int j = 0; j < allowedCitys.size(); j++) {
if (randomNum > tempPro && randomNum <= tempPro + proArray[j]) {
// 采用拷贝的方式避免引用重复
nextCity = allowedCitys.get(j);
break;
} else {
tempPro += proArray[j];
}
}
return nextCity;
}
/**
* 获取给定时间点上从城市i到城市j的信息素浓度
*
* @param t
* @param cityI
* @param cityJ
* @return
*/
private double getPheromone(int t, String cityI, String cityJ) {
double pheromone = 0;
String key;
// 上一周期需将时间倒回一周期
key = MessageFormat.format("{0},{1},{2}", cityI, cityJ, t);
if (pheromoneTimeMap.containsKey(key)) {
pheromone = pheromoneTimeMap.get(key);
}
return pheromone;
}
/**
* 每轮结束,刷新信息素浓度矩阵
*
* @param t
*/
private void refreshPheromone(int t) {
double pheromone = 0;
// 上一轮周期结束后的信息素浓度,丛信息素浓度图中查找
double lastTimeP = 0;
// 本轮信息素浓度增加量
double addPheromone;
String key;
for (String i : totalCitys) {
for (String j : totalCitys) {
if (!i.equals(j)) {
// 上一周期需将时间倒回一周期
key = MessageFormat.format("{0},{1},{2}", i, j, t - 1);
if (pheromoneTimeMap.containsKey(key)) {
lastTimeP = pheromoneTimeMap.get(key);
} else {
lastTimeP = 0;
}
addPheromone = 0;
for (Ant ant : totalAnts) {
if(ant.pathContained(i, j)){
// 每只蚂蚁传播的信息素为控制因子除以距离总成本
addPheromone += Q / ant.calSumDistance();
}
}
// 将上次的结果值加上递增的量,并存入图中
pheromone = p * lastTimeP + addPheromone;
key = MessageFormat.format("{0},{1},{2}", i, j, t);
pheromoneTimeMap.put(key, pheromone);
}
}
}
}
/**
* 蚁群算法迭代次数
* @param loopCount
* 具体遍历次数
*/
public void antStartSearching(int loopCount) {
// 蚁群寻找的总次数
int count = 0;
// 选中的下一个城市
String selectedCity = "";
pheromoneTimeMap = new HashMap();
totalAnts = new ArrayList<>();
random = new Random();
while (count < loopCount) {
initAnts();
while (true) {
for (Ant ant : totalAnts) {
selectedCity = selectAntNextCity(ant, currentTime);
ant.goToNextCity(selectedCity);
}
// 如果已经遍历完所有城市,则跳出此轮循环
if (totalAnts.get(0).isBack()) {
break;
}
}
// 周期时间叠加
currentTime++;
refreshPheromone(currentTime);
count++;
}
// 根据距离成本,选出所花距离最短的一个路径
Collections.sort(totalAnts);
bestPath = totalAnts.get(0).currentPath;
System.out.println(MessageFormat.format("经过{0}次循环遍历,最终得出的最佳路径:", count));
System.out.print("entrance");
for (String cityName : bestPath) {
System.out.print(MessageFormat.format("-->{0}", cityName));
}
}
/**
* 初始化蚁群操作
*/
private void initAnts() {
Ant tempAnt;
ArrayList nonVisitedCitys;
totalAnts.clear();
// 初始化蚁群
for (int i = 0; i < antNum; i++) {
nonVisitedCitys = (ArrayList) totalCitys.clone();
tempAnt = new Ant(pheromoneMatrix, nonVisitedCitys);
totalAnts.add(tempAnt);
}
}
}
场景测试类Client.java:
package DataMining_ACO;
/**
* 蚁群算法测试类
* @author lyq
*
*/
public class Client {
public static void main(String[] args){
//测试数据
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
//蚂蚁数量
int antNum;
//蚁群算法迭代次数
int loopCount;
//控制参数
double alpha;
double beita;
double p;
double Q;
antNum = 3;
alpha = 0.5;
beita = 1;
p = 0.5;
Q = 5;
loopCount = 5;
ACOTool tool = new ACOTool(filePath, antNum, alpha, beita, p, Q);
tool.antStartSearching(loopCount);
}
}
算法的输出,就是在多次搜索之后,找到的路径中最短的一个路径:
经过5次循环遍历,最终得出的最佳路径:
entrance-->4-->1-->2-->3-->4
因为数据量比较小,并不能看出蚁群算法在这方面的优势,博友们可以再次基础上自行改造,并用大一点的数据做测试,其中的4个控制因子也可以调控。蚁群算法作为一种启发式算法,还可以和遗传算法结合,创造出更优的算法。蚁群算法可以解决许多这样的连通图路径优化问题。但是有的时候也会出现搜索时间过长的问题。
参考文献:百度百科.蚁群算法
我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
作者:Androidlushangderen 发表于2015/4/30 15:31:45 原文链接
阅读:925 评论:0 查看评论
从Apriori到MS-Apriori算法
2015年4月16日 22:42
前言
最近的几个月一直在研究和学习各种经典的DM,机器学习的相关算法,收获还是挺多的,另外还整了一个DM算法库,集成了很多数据挖掘算法,放在了我的github上,博友的热度超出我的想象,有很多人给我点了star,在此感谢各大博友们,我将会继续更新我的DM算法库。也许这些算法还不能直接拿来用,但是可以给你提供思路,或变变数据的输入格式就能用了。好,扯得有点远了,现在说正题,本篇文章重新回到讲述Apriori算法,当然我这次不会讲之前说过的Apriori,(比较老套的那些东西网上也很多,我分析的也不一定是最好),本文的主题是Apriori算法的升级版本算法--MS-Apriori。在前面加了Ms,是什么意思呢,他可不是升级的意思,Ms是Mis的缩写,MIS的全称是Min Item Support,最小项目支持度。这有何Apriori算法有什么关系呢,在后面的正文中,我会主要解释这是什么意思,其实这只是其中的一个小的点,Ms-Apriori还是有很多对于Apriori算法的改进的。
Apriori
在了解Ms-Apriori算法之前,还是有必要重新回顾一下Apriori算法,Apriori算法是一种演绎算法,后一次的结果是依赖于上一次的计算结果的,算法的目的就是通过给定的数据挖掘出其中的频繁项,进而推导出关联规则,属于模式挖掘的范畴。Apriori算法的核心步骤可以概括为2个过程,1个是连接运算,1个剪枝运算,这具体的过程就不详细说了,如果想要了解的话,请点击我的Apriori算法分析。尽管Apriori算法在一定的程度上看起来非常的好用,但是他并不是十全十美的,首先在选择的类型上就存在限制,他无法照顾到不同类型的频繁项的挖掘。比如说一些稀有项目的挖掘,比如部分奢侈品。换句话说,如果最小支持度设置的过大,就会导致这些稀有的项集就很难被发现,于是我们就想把这个最小支持度值调得足够小不久OK了吗,事实并非这么简单,支持度调小的话,会造成巨大量的频繁项集合候选项的产生,同时会有大量的一些无关的关联规则被推导出来,当然这个问题就是ms-apriori所要解决的主要问题。下面看看ms-apropri给出了怎么样的解决办法。
Ms-Apriori
Ms-Apriori算法采用另外一种办法,既然统一的支持度值不能兼顾所有的情况,那我可以设置多个支持度值啊,每个种类项都有一个最小支持度阈值,然后一个频繁项的最小支持度阈值取其中项集元素中的最小支持度值作为该项集的最小支持度值。这样的话,如果一个频繁项中出现了稀有项集,则这个项集的最小支持度值就会被拉低,如果又有某个项都是出现频率很高的项构成的话,则支持度阈值又会被拉高。当然,如果出现了一个比较难变态的情况就是,频繁项中同时出现了稀有项和普通项,我们可以通过设置SDC支持度差别限制来限制这种情况的发生,使之挖掘的频繁项更加的合理。通过这里的描述,你就可以发现,当mis最小支持度阈值数组的个数只有1个的时候,ms-apriori算法就退化成了Apriori算法了。
其实ms-apriori算法在某些细节的处理上也有对原先的算法做过一定的优化,这里提及2点。
1、每个候选项的支持度值的统计
原先Apriori算法的操作是扫描整个数据集,进行计数的统计,说白了就是线性扫描一遍,效率自不必说,但是如果你自己思考,其实每次的统计的结果一定不会超过他的上一次子集的结果值,因为他是从上一次的计算过程演绎而来的,当前项集的结果是包含了子项集的结果的,所以改进的算法是每次从子集的计数量中再次计算支持度值,具体操作详见后面我的代码实现,效率还是提高了不少。
2、第二是关联规则的推导
找到了所有的频繁项,找出其中的关联规则最笨的办法就是一个个去算置信度,然后输出满足要求条件的规则,但是其实这里面也包含有上条规则中类似的特点,举个例子,如果已经有一条规则,{1}-->{2, 3, 4},代表在存在1的情况下能退出2,3,4,的存在,那么我们就一定能退出{1, 2}--->{3, 4},因为这是后者的情况其实是被包含于前者的情况的,如果你还不能理解,代入置信度计算的公式,分子相同的情况下,{1,2}发生的情况数一定小于或等于{1}的情况,于是整个置信度必定{1,2}的大于{1}的情况。
关联规则挖掘的数据格式
这里再随便说说关联规则的数据格式,也许在很多书中,用于进行Apriori这类算法的测试的数据都是事务型的数据,其实不是的关系表型的数据同样可以做关联规则的挖掘,不过这需要经过一步预处理的方式,让机器能够更好的识别,推荐一种常见的做法,就是采用属性名+属性值的方式,单单用属性值是不够的,因为属性值是在不同的属性中可能会有重,这点在CBA(基于关联规则分类算法)中也提到过一些,具体的可以查阅CBA基于关联规则分类。
MS-Apriori算法的代码实现
算法的测试我采用了2种类型数据做测试一种是事务型数据,一种是非事务型的数据,输入分别如下:
input.txt:
T1 1 2 5
T2 2 4
T3 2 3
T4 1 2 4
T5 1 3
T6 2 3
T7 1 3
T8 1 2 3 5
T9 1 2 3
input2.txt
Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No
算法工具类MSAprioriTool.java:
package DataMining_MSApriori;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import DataMining_Apriori.FrequentItem;
/**
* 基于多支持度的Apriori算法工具类
*
* @author lyq
*
*/
public class MSAprioriTool {
// 前件判断的结果值,用于关联规则的推导
public static final int PREFIX_NOT_SUB = -1;
public static final int PREFIX_EQUAL = 1;
public static final int PREFIX_IS_SUB = 2;
// 是否读取的是事务型数据
private boolean isTransaction;
// 最大频繁k项集的k值
private int initFItemNum;
// 事务数据文件地址
private String filePath;
// 最小支持度阈值
private double minSup;
// 最小置信度率
private double minConf;
// 最大支持度差别阈值
private double delta;
// 多项目的最小支持度数,括号中的下标代表的是商品的ID
private double[] mis;
// 每个事务中的商品ID
private ArrayList totalGoodsIDs;
// 关系表数据所转化的事务数据
private ArrayList transactionDatas;
// 过程中计算出来的所有频繁项集列表
private ArrayList resultItem;
// 过程中计算出来频繁项集的ID集合
private ArrayList resultItemID;
// 属性到数字的映射图
private HashMap attr2Num;
// 数字id对应属性的映射图
private HashMap num2Attr;
// 频繁项集所覆盖的id数值
private Map fItem2Id;
/**
* 事务型数据关联挖掘算法
*
* @param filePath
* @param minConf
* @param delta
* @param mis
* @param isTransaction
*/
public MSAprioriTool(String filePath, double minConf, double delta,
double[] mis, boolean isTransaction) {
this.filePath = filePath;
this.minConf = minConf;
this.delta = delta;
this.mis = mis;
this.isTransaction = isTransaction;
this.fItem2Id = new HashMap<>();
readDataFile();
}
/**
* 非事务型关联挖掘
*
* @param filePath
* @param minConf
* @param minSup
* @param isTransaction
*/
public MSAprioriTool(String filePath, double minConf, double minSup,
boolean isTransaction) {
this.filePath = filePath;
this.minConf = minConf;
this.minSup = minSup;
this.isTransaction = isTransaction;
this.delta = 1.0;
this.fItem2Id = new HashMap<>();
readRDBMSData(filePath);
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
String[] temp = null;
ArrayList dataArray;
dataArray = readLine(filePath);
totalGoodsIDs = new ArrayList<>();
for (String[] array : dataArray) {
temp = new String[array.length - 1];
System.arraycopy(array, 1, temp, 0, array.length - 1);
// 将事务ID加入列表吧中
totalGoodsIDs.add(temp);
}
}
/**
* 从文件中逐行读数据
*
* @param filePath
* 数据文件地址
* @return
*/
private ArrayList readLine(String filePath) {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
return dataArray;
}
/**
* 计算频繁项集
*/
public void calFItems() {
FrequentItem fItem;
computeLink();
printFItems();
if (isTransaction) {
fItem = resultItem.get(resultItem.size() - 1);
// 取出最后一个频繁项集做关联规则的推导
System.out.println("最后一个频繁项集做关联规则的推导结果:");
printAttachRuls(fItem.getIdArray());
}
}
/**
* 输出频繁项集
*/
private void printFItems() {
if (isTransaction) {
System.out.println("事务型数据频繁项集输出结果:");
} else {
System.out.println("非事务(关系)型数据频繁项集输出结果:");
}
// 输出频繁项集
for (int k = 1; k <= initFItemNum; k++) {
System.out.println("频繁" + k + "项集:");
for (FrequentItem i : resultItem) {
if (i.getLength() == k) {
System.out.print("{");
for (String t : i.getIdArray()) {
if (!isTransaction) {
// 如果原本是非事务型数据,需要重新做替换
t = num2Attr.get(Integer.parseInt(t));
}
System.out.print(t + ",");
}
System.out.print("},");
}
}
System.out.println();
}
}
/**
* 项集进行连接运算
*/
private void computeLink() {
// 连接计算的终止数,k项集必须算到k-1子项集为止
int endNum = 0;
// 当前已经进行连接运算到几项集,开始时就是1项集
int currentNum = 1;
// 商品,1频繁项集映射图
HashMap itemMap = new HashMap<>();
FrequentItem tempItem;
// 初始列表
ArrayList list = new ArrayList<>();
// 经过连接运算后产生的结果项集
resultItem = new ArrayList<>();
resultItemID = new ArrayList<>();
// 商品ID的种类
ArrayList idType = new ArrayList<>();
for (String[] a : totalGoodsIDs) {
for (String s : a) {
if (!idType.contains(s)) {
tempItem = new FrequentItem(new String[] { s }, 1);
idType.add(s);
resultItemID.add(new String[] { s });
} else {
// 支持度计数加1
tempItem = itemMap.get(s);
tempItem.setCount(tempItem.getCount() + 1);
}
itemMap.put(s, tempItem);
}
}
// 将初始频繁项集转入到列表中,以便继续做连接运算
for (Map.Entry entry : itemMap.entrySet()) {
tempItem = entry.getValue();
// 判断1频繁项集是否满足支持度阈值的条件
if (judgeFItem(tempItem.getIdArray())) {
list.add(tempItem);
}
}
// 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
Collections.sort(list);
resultItem.addAll(list);
String[] array1;
String[] array2;
String[] resultArray;
ArrayList tempIds;
ArrayList resultContainer;
// 总共要算到endNum项集
endNum = list.size() - 1;
initFItemNum = list.size() - 1;
while (currentNum < endNum) {
resultContainer = new ArrayList<>();
for (int i = 0; i < list.size() - 1; i++) {
tempItem = list.get(i);
array1 = tempItem.getIdArray();
for (int j = i + 1; j < list.size(); j++) {
tempIds = new ArrayList<>();
array2 = list.get(j).getIdArray();
for (int k = 0; k < array1.length; k++) {
// 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
if (array1[k].equals(array2[k])) {
tempIds.add(array1[k]);
} else {
tempIds.add(array1[k]);
tempIds.add(array2[k]);
}
}
resultArray = new String[tempIds.size()];
tempIds.toArray(resultArray);
boolean isContain = false;
// 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
if (resultArray.length == (array1.length + 1)) {
isContain = isIDArrayContains(resultContainer,
resultArray);
if (!isContain) {
resultContainer.add(resultArray);
}
}
}
}
// 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
list = cutItem(resultContainer);
currentNum++;
}
}
/**
* 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
*/
private ArrayList cutItem(ArrayList resultIds) {
String[] temp;
// 忽略的索引位置,以此构建子集
int igNoreIndex = 0;
FrequentItem tempItem;
// 剪枝生成新的频繁项集
ArrayList newItem = new ArrayList<>();
// 不符合要求的id
ArrayList deleteIdArray = new ArrayList<>();
// 子项集是否也为频繁子项集
boolean isContain = true;
for (String[] array : resultIds) {
// 列举出其中的一个个的子项集,判断存在于频繁项集列表中
temp = new String[array.length - 1];
for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) {
isContain = true;
for (int j = 0, k = 0; j < array.length; j++) {
if (j != igNoreIndex) {
temp[k] = array[j];
k++;
}
}
if (!isIDArrayContains(resultItemID, temp)) {
isContain = false;
break;
}
}
if (!isContain) {
deleteIdArray.add(array);
}
}
// 移除不符合条件的ID组合
resultIds.removeAll(deleteIdArray);
// 移除支持度计数不够的id集合
int tempCount = 0;
boolean isSatisfied = false;
for (String[] array : resultIds) {
isSatisfied = judgeFItem(array);
// 如果此频繁项集满足多支持度阈值限制条件和支持度差别限制条件,则添加入结果集中
if (isSatisfied) {
tempItem = new FrequentItem(array, tempCount);
newItem.add(tempItem);
resultItemID.add(array);
resultItem.add(tempItem);
}
}
return newItem;
}
/**
* 判断列表结果中是否已经包含此数组
*
* @param container
* ID数组容器
* @param array
* 待比较数组
* @return
*/
private boolean isIDArrayContains(ArrayList container,
String[] array) {
boolean isContain = true;
if (container.size() == 0) {
isContain = false;
return isContain;
}
for (String[] s : container) {
// 比较的视乎必须保证长度一样
if (s.length != array.length) {
continue;
}
isContain = true;
for (int i = 0; i < s.length; i++) {
// 只要有一个id不等,就算不相等
if (s[i] != array[i]) {
isContain = false;
break;
}
}
// 如果已经判断是包含在容器中时,直接退出
if (isContain) {
break;
}
}
return isContain;
}
/**
* 判断一个频繁项集是否满足条件
*
* @param frequentItem
* 待判断频繁项集
* @return
*/
private boolean judgeFItem(String[] frequentItem) {
boolean isSatisfied = true;
int id;
int count;
double tempMinSup;
// 最小的支持度阈值
double minMis = Integer.MAX_VALUE;
// 最大的支持度阈值
double maxMis = -Integer.MAX_VALUE;
// 如果是事务型数据,用mis数组判断,如果不是统一用同样的最小支持度阈值判断
if (isTransaction) {
// 寻找频繁项集中的最小支持度阈值
for (int i = 0; i < frequentItem.length; i++) {
id = i + 1;
if (mis[id] < minMis) {
minMis = mis[id];
}
if (mis[id] > maxMis) {
maxMis = mis[id];
}
}
} else {
minMis = minSup;
maxMis = minSup;
}
count = calSupportCount(frequentItem);
tempMinSup = 1.0 * count / totalGoodsIDs.size();
// 判断频繁项集的支持度阈值是否超过最小的支持度阈值
if (tempMinSup < minMis) {
isSatisfied = false;
}
// 如果误差超过了最大支持度差别,也算不满足条件
if (Math.abs(maxMis - minMis) > delta) {
isSatisfied = false;
}
return isSatisfied;
}
/**
* 统计候选频繁项集的支持度数,利用他的子集进行技术,无须扫描整个数据集
*
* @param frequentItem
* 待计算频繁项集
* @return
*/
private int calSupportCount(String[] frequentItem) {
int count = 0;
int[] ids;
String key;
String[] array;
ArrayList newIds;
key = "";
for (int i = 1; i < frequentItem.length; i++) {
key += frequentItem[i];
}
newIds = new ArrayList<>();
// 找出所属的事务ID
ids = fItem2Id.get(key);
// 如果没有找到子项集的事务id,则全盘扫描数据集
if (ids == null || ids.length == 0) {
for (int j = 0; j < totalGoodsIDs.size(); j++) {
array = totalGoodsIDs.get(j);
if (isStrArrayContain(array, frequentItem)) {
count++;
newIds.add(j);
}
}
} else {
for (int index : ids) {
array = totalGoodsIDs.get(index);
if (isStrArrayContain(array, frequentItem)) {
count++;
newIds.add(index);
}
}
}
ids = new int[count];
for (int i = 0; i < ids.length; i++) {
ids[i] = newIds.get(i);
}
key = frequentItem[0] + key;
// 将所求值存入图中,便于下次的计数
fItem2Id.put(key, ids);
return count;
}
/**
* 根据给定的频繁项集输出关联规则
*
* @param frequentItems
* 频繁项集
*/
public void printAttachRuls(String[] frequentItem) {
// 关联规则前件,后件对
Map, ArrayList> rules;
// 前件搜索历史
Map, ArrayList> searchHistory;
ArrayList prefix;
ArrayList suffix;
rules = new HashMap, ArrayList>();
searchHistory = new HashMap<>();
for (int i = 0; i < frequentItem.length; i++) {
suffix = new ArrayList<>();
for (int j = 0; j < frequentItem.length; j++) {
suffix.add(frequentItem[j]);
}
prefix = new ArrayList<>();
recusiveFindRules(rules, searchHistory, prefix, suffix);
}
// 依次输出找到的关联规则
for (Map.Entry, ArrayList> entry : rules
.entrySet()) {
prefix = entry.getKey();
suffix = entry.getValue();
printRuleDetail(prefix, suffix);
}
}
/**
* 根据前件后件,输出关联规则
*
* @param prefix
* @param suffix
*/
private void printRuleDetail(ArrayList prefix,
ArrayList suffix) {
// {A}-->{B}的意思为在A的情况下发生B的概率
System.out.print("{");
for (String s : prefix) {
System.out.print(s + ", ");
}
System.out.print("}-->");
System.out.print("{");
for (String s : suffix) {
System.out.print(s + ", ");
}
System.out.println("}");
}
/**
* 递归扩展关联规则解
*
* @param rules
* 关联规则结果集
* @param history
* 前件搜索历史
* @param prefix
* 关联规则前件
* @param suffix
* 关联规则后件
*/
private void recusiveFindRules(
Map, ArrayList> rules,
Map, ArrayList> history,
ArrayList prefix, ArrayList suffix) {
int count1;
int count2;
int compareResult;
// 置信度大小
double conf;
String[] temp1;
String[] temp2;
ArrayList copyPrefix;
ArrayList copySuffix;
// 如果后件只有1个,则函数返回
if (suffix.size() == 1) {
return;
}
for (String s : suffix) {
count1 = 0;
count2 = 0;
copyPrefix = (ArrayList) prefix.clone();
copyPrefix.add(s);
copySuffix = (ArrayList) suffix.clone();
// 将拷贝的后件移除添加的一项
copySuffix.remove(s);
compareResult = isSubSetInRules(history, copyPrefix);
if (compareResult == PREFIX_EQUAL) {
// 如果曾经已经被搜索过,则跳过
continue;
}
// 判断是否为子集,如果是子集则无需计算
compareResult = isSubSetInRules(rules, copyPrefix);
if (compareResult == PREFIX_IS_SUB) {
rules.put(copyPrefix, copySuffix);
// 加入到搜索历史中
history.put(copyPrefix, copySuffix);
recusiveFindRules(rules, history, copyPrefix, copySuffix);
continue;
}
// 暂时合并为总的集合
copySuffix.addAll(copyPrefix);
temp1 = new String[copyPrefix.size()];
temp2 = new String[copySuffix.size()];
copyPrefix.toArray(temp1);
copySuffix.toArray(temp2);
// 之后再次移除之前天剑的前件
copySuffix.removeAll(copyPrefix);
for (String[] a : totalGoodsIDs) {
if (isStrArrayContain(a, temp1)) {
count1++;
// 在group1的条件下,统计group2的事件发生次数
if (isStrArrayContain(a, temp2)) {
count2++;
}
}
}
conf = 1.0 * count2 / count1;
if (conf > minConf) {
// 设置此前件条件下,能导出关联规则
rules.put(copyPrefix, copySuffix);
}
// 加入到搜索历史中
history.put(copyPrefix, copySuffix);
recusiveFindRules(rules, history, copyPrefix, copySuffix);
}
}
/**
* 判断当前的前件是否会关联规则的子集
*
* @param rules
* 当前已经判断出的关联规则
* @param prefix
* 待判断的前件
* @return
*/
private int isSubSetInRules(
Map, ArrayList> rules,
ArrayList prefix) {
int result = PREFIX_NOT_SUB;
String[] temp1;
String[] temp2;
ArrayList tempPrefix;
for (Map.Entry, ArrayList> entry : rules
.entrySet()) {
tempPrefix = entry.getKey();
temp1 = new String[tempPrefix.size()];
temp2 = new String[prefix.size()];
tempPrefix.toArray(temp1);
prefix.toArray(temp2);
// 判断当前构造的前件是否已经是存在前件的子集
if (isStrArrayContain(temp2, temp1)) {
if (temp2.length == temp1.length) {
result = PREFIX_EQUAL;
} else {
result = PREFIX_IS_SUB;
}
}
if (result == PREFIX_EQUAL) {
break;
}
}
return result;
}
/**
* 数组array2是否包含于array1中,不需要完全一样
*
* @param array1
* @param array2
* @return
*/
private boolean isStrArrayContain(String[] array1, String[] array2) {
boolean isContain = true;
for (String s2 : array2) {
isContain = false;
for (String s1 : array1) {
// 只要s2字符存在于array1中,这个字符就算包含在array1中
if (s2.equals(s1)) {
isContain = true;
break;
}
}
// 一旦发现不包含的字符,则array2数组不包含于array1中
if (!isContain) {
break;
}
}
return isContain;
}
/**
* 读关系表中的数据,并转化为事务数据
*
* @param filePath
*/
private void readRDBMSData(String filePath) {
String str;
// 属性名称行
String[] attrNames = null;
String[] temp;
String[] newRecord;
ArrayList datas = null;
datas = readLine(filePath);
// 获取首行
attrNames = datas.get(0);
this.transactionDatas = new ArrayList<>();
// 去除首行数据
for (int i = 1; i < datas.size(); i++) {
temp = datas.get(i);
// 过滤掉首列id列
for (int j = 1; j < temp.length; j++) {
str = "";
// 采用属性名+属性值的形式避免数据的重复
str = attrNames[j] + ":" + temp[j];
temp[j] = str;
}
newRecord = new String[attrNames.length - 1];
System.arraycopy(temp, 1, newRecord, 0, attrNames.length - 1);
this.transactionDatas.add(newRecord);
}
attributeReplace();
// 将事务数转到totalGoodsID中做统一处理
this.totalGoodsIDs = transactionDatas;
}
/**
* 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
*/
private void attributeReplace() {
int currentValue = 1;
String s;
// 属性名到数字的映射图
attr2Num = new HashMap<>();
num2Attr = new HashMap<>();
// 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
for (int j = 0; j < transactionDatas.get(0).length; j++) {
for (int i = 0; i < transactionDatas.size(); i++) {
s = transactionDatas.get(i)[j];
if (!attr2Num.containsKey(s)) {
attr2Num.put(s, currentValue);
num2Attr.put(currentValue, s);
transactionDatas.get(i)[j] = currentValue + "";
currentValue++;
} else {
transactionDatas.get(i)[j] = attr2Num.get(s) + "";
}
}
}
}
}
频繁项集类FrequentItem.java:
package DataMining_MSApriori;
/**
* 频繁项集
*
* @author lyq
*
*/
public class FrequentItem implements Comparable{
// 频繁项集的集合ID
private String[] idArray;
// 频繁项集的支持度计数
private int count;
//频繁项集的长度,1项集或是2项集,亦或是3项集
private int length;
public FrequentItem(String[] idArray, int count){
this.idArray = idArray;
this.count = count;
length = idArray.length;
}
public String[] getIdArray() {
return idArray;
}
public void setIdArray(String[] idArray) {
this.idArray = idArray;
}
public int getCount() {
return count;
}
public void setCount(int count) {
this.count = count;
}
public int getLength() {
return length;
}
public void setLength(int length) {
this.length = length;
}
@Override
public int compareTo(FrequentItem o) {
// TODO Auto-generated method stub
Integer int1 = Integer.parseInt(this.getIdArray()[0]);
Integer int2 = Integer.parseInt(o.getIdArray()[0]);
return int1.compareTo(int2);
}
}
测试类Client.java:
package DataMining_MSApriori;
/**
* 基于多支持度的Apriori算法测试类
* @author lyq
*
*/
public class Client {
public static void main(String[] args){
//是否是事务型数据
boolean isTransaction;
//测试数据文件地址
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
//关系表型数据文件地址
String tableFilePath = "C:\\Users\\lyq\\Desktop\\icon\\input2.txt";
//最小支持度阈值
double minSup;
// 最小置信度率
double minConf;
//最大支持度差别阈值
double delta;
//多项目的最小支持度数,括号中的下标代表的是商品的ID
double[] mis;
//msApriori算法工具类
MSAprioriTool tool;
//为了测试的方便,取一个偏低的置信度值0.3
minConf = 0.3;
minSup = 0.1;
delta = 0.5;
//每项的支持度率都默认为0.1,第一项不使用
mis = new double[]{-1, 0.1, 0.1, 0.1, 0.1, 0.1};
isTransaction = true;
isTransaction = true;
tool = new MSAprioriTool(filePath, minConf, delta, mis, isTransaction);
tool.calFItems();
System.out.println();
isTransaction = false;
//重新初始化数据
tool = new MSAprioriTool(tableFilePath, minConf, minSup, isTransaction);
tool.calFItems();
}
}
算法输出(输出的内容有点多):
事务型数据频繁项集输出结果:
频繁1项集:
{1,},{2,},{3,},{4,},{5,},
频繁2项集:
{1,2,},{1,3,},{1,4,},{1,5,},{2,3,},{2,4,},{2,5,},{3,5,},
频繁3项集:
{1,2,3,},{1,2,4,},{1,2,5,},{1,3,5,},{2,3,5,},
频繁4项集:
{1,2,3,5,},
最后一个频繁项集做关联规则的推导结果:
{2, 5, }-->{1, 3, }
{5, }-->{1, 2, 3, }
{3, 5, }-->{1, 2, }
{1, 5, }-->{2, 3, }
{2, 3, 5, }-->{1, }
{1, 2, 5, }-->{3, }
{1, 3, 5, }-->{2, }
{1, 2, 3, }-->{5, }
非事务(关系)型数据频繁项集输出结果:
频繁1项集:
{Age:Youth,},{Age:MiddleAged,},{Age:Senior,},{Income:High,},{Income:Medium,},{Income:Low,},{Student:No,},{Student:Yes,},{CreditRating:Fair,},{CreditRating:Excellent,},{BuysComputer:No,},{BuysComputer:Yes,},
频繁2项集:
{Age:Youth,Income:High,},{Age:Youth,Income:Medium,},{Age:Youth,Student:No,},{Age:Youth,Student:Yes,},{Age:Youth,CreditRating:Fair,},{Age:Youth,CreditRating:Excellent,},{Age:Youth,BuysComputer:No,},{Age:Youth,BuysComputer:Yes,},{Age:MiddleAged,Income:High,},{Age:MiddleAged,Student:No,},{Age:MiddleAged,Student:Yes,},{Age:MiddleAged,CreditRating:Fair,},{Age:MiddleAged,CreditRating:Excellent,},{Age:MiddleAged,BuysComputer:Yes,},{Age:Senior,Income:Medium,},{Age:Senior,Income:Low,},{Age:Senior,Student:No,},{Age:Senior,Student:Yes,},{Age:Senior,CreditRating:Fair,},{Age:Senior,CreditRating:Excellent,},{Age:Senior,BuysComputer:No,},{Age:Senior,BuysComputer:Yes,},{Income:High,Student:No,},{Income:High,CreditRating:Fair,},{Income:High,BuysComputer:No,},{Income:High,BuysComputer:Yes,},{Income:Medium,Student:No,},{Income:Medium,Student:Yes,},{Income:Medium,CreditRating:Fair,},{Income:Medium,CreditRating:Excellent,},{Income:Medium,BuysComputer:No,},{Income:Medium,BuysComputer:Yes,},{Income:Low,Student:Yes,},{Income:Low,CreditRating:Fair,},{Income:Low,CreditRating:Excellent,},{Income:Low,BuysComputer:Yes,},{Student:No,CreditRating:Fair,},{Student:No,CreditRating:Excellent,},{Student:No,BuysComputer:No,},{Student:No,BuysComputer:Yes,},{Student:Yes,CreditRating:Fair,},{Student:Yes,CreditRating:Excellent,},{Student:Yes,BuysComputer:Yes,},{CreditRating:Fair,BuysComputer:No,},{CreditRating:Fair,BuysComputer:Yes,},{CreditRating:Excellent,BuysComputer:No,},{CreditRating:Excellent,BuysComputer:Yes,},
频繁3项集:
{Age:Youth,Income:High,Student:No,},{Age:Youth,Income:High,BuysComputer:No,},{Age:Youth,Student:No,CreditRating:Fair,},{Age:Youth,Student:No,BuysComputer:No,},{Age:Youth,Student:Yes,BuysComputer:Yes,},{Age:Youth,CreditRating:Fair,BuysComputer:No,},{Age:MiddleAged,Income:High,CreditRating:Fair,},{Age:MiddleAged,Income:High,BuysComputer:Yes,},{Age:MiddleAged,Student:No,BuysComputer:Yes,},{Age:MiddleAged,Student:Yes,BuysComputer:Yes,},{Age:MiddleAged,CreditRating:Fair,BuysComputer:Yes,},{Age:MiddleAged,CreditRating:Excellent,BuysComputer:Yes,},{Age:Senior,Income:Medium,Student:No,},{Age:Senior,Income:Medium,CreditRating:Fair,},{Age:Senior,Income:Medium,BuysComputer:Yes,},{Age:Senior,Income:Low,Student:Yes,},{Age:Senior,Student:Yes,CreditRating:Fair,},{Age:Senior,Student:Yes,BuysComputer:Yes,},{Age:Senior,CreditRating:Fair,BuysComputer:Yes,},{Age:Senior,CreditRating:Excellent,BuysComputer:No,},{Income:High,Student:No,CreditRating:Fair,},{Income:High,Student:No,BuysComputer:No,},{Income:High,CreditRating:Fair,BuysComputer:Yes,},{Income:Medium,Student:No,CreditRating:Fair,},{Income:Medium,Student:No,CreditRating:Excellent,},{Income:Medium,Student:No,BuysComputer:No,},{Income:Medium,Student:No,BuysComputer:Yes,},{Income:Medium,Student:Yes,BuysComputer:Yes,},{Income:Medium,CreditRating:Fair,BuysComputer:Yes,},{Income:Medium,CreditRating:Excellent,BuysComputer:Yes,},{Income:Low,Student:Yes,CreditRating:Fair,},{Income:Low,Student:Yes,CreditRating:Excellent,},{Income:Low,Student:Yes,BuysComputer:Yes,},{Income:Low,CreditRating:Fair,BuysComputer:Yes,},{Student:No,CreditRating:Fair,BuysComputer:No,},{Student:No,CreditRating:Fair,BuysComputer:Yes,},{Student:No,CreditRating:Excellent,BuysComputer:No,},{Student:Yes,CreditRating:Fair,BuysComputer:Yes,},{Student:Yes,CreditRating:Excellent,BuysComputer:Yes,},
频繁4项集:
{Age:Youth,Income:High,Student:No,BuysComputer:No,},{Age:Youth,Student:No,CreditRating:Fair,BuysComputer:No,},{Age:MiddleAged,Income:High,CreditRating:Fair,BuysComputer:Yes,},{Age:Senior,Income:Medium,CreditRating:Fair,BuysComputer:Yes,},{Age:Senior,Student:Yes,CreditRating:Fair,BuysComputer:Yes,},{Income:Low,Student:Yes,CreditRating:Fair,BuysComputer:Yes,},
频繁5项集:
频繁6项集:
频繁7项集:
频繁8项集:
频繁9项集:
频繁10项集:
频繁11项集:
参考文献:刘兵.<> 第一部分.第二章.关联规则和序列模式
我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
作者:Androidlushangderen 发表于2015/4/16 22:42:53 原文链接
阅读:594 评论:0 查看评论
多维空间分割树--KD树
2015年4月10日 21:39
算法介绍
KD树的全称为k-Dimension Tree的简称,是一种分割K维空间的数据结构,主要应用于关键信息的搜索。为什么说是K维的呢,因为这时候的空间不仅仅是2维度的,他可能是3维,4维度的或者是更多。我们举个例子,如果是二维的空间,对于其中的空间进行分割的就是一条条的分割线,比如说下面这个样子。
如果是3维的呢,那么分割的媒介就是一个平面了,下面是3维空间的分割
这就稍稍有点抽象了,如果是3维以上,我们把这样的分割媒介可以统统叫做超平面 。那么KD树算法有什么特别之处呢,还有他与K-NN算法之间又有什么关系呢,这将是下面所将要描述的。
KNN
KNN就是K最近邻算法,他是一个分类算法,因为算法简单,分类效果也还不错,也被许多人使用着,算法的原理就是选出与给定数据最近的k个数据,然后根据k个数据中占比最多的分类作为测试数据的最终分类。图示如下:
算法固然简单,但是其中通过逐个去比较的办法求得最近的k个数据点,效率太低,时间复杂度会随着训练数据数量的增多而线性增长。于是就需要一种更加高效快速的办法来找到所给查询点的最近邻,而KD树就是其中的一种行之有效的办法。但是不管是KNN算法还是KD树算法,他们都属于相似性查询中的K近邻查询的范畴。在相似性查询算法中还有一类查询是范围查询,就是给定距离阈值和查询点,dbscan算法可以说是一种范围查询,基于给定点进行局部密度范围的搜索。想要了解KNN算法或者是Dbscan算法的可以点击我的K-最近邻算法和Dbscan基于密度的聚类算法。
KD-Tree
在KNN算法中,针对查询点数据的查找采用的是线性扫描的方法,说白了就是暴力比较,KD树在这方面用了二分划分的思想,将数据进行逐层空间上的划分,大大的提高了查询的速度,可以理解为一个变形的二分搜索时间,只不过这个适用到了多维空间的层次上。下面是二维空间的情况下,数据的划分结果:
现在看到的图在逻辑上的意思就是一棵完整的二叉树,虚线上的点是叶子节点。
KD树的算法原理
KD树的算法的实现原理并不是那么好理解,主要分为树的构建和基于KD树进行最近邻的查询2个过程,后者比前者更加复杂。当然,要想实现最近点的查询,首先我们得先理解KD树的构建过程。下面是KD树节点的定义,摘自百度百科:
域名
数据类型
描述
Node-data
数据矢量
数据集中某个数据点,是n维矢量(这里也就是k维)
Range
空间矢量
该节点所代表的空间范围
split
整数
垂直于分割超平面的方向轴序号
Left
k-d树
由位于该节点分割超平面左子空间内所有数据点所构成的k-d树
Right
k-d树
由位于该节点分割超平面右子空间内所有数据点所构成的k-d树
parent
k-d树
父节点
变量还是有点多的,节点中有孩子节点和父亲节点,所以必然会用到递归。KD树的构建算法过程如下(这里假设构建的是2维KD树,简单易懂,后续同上):
1、首先将数据节点坐标中的X坐标和Y坐标进行方差计算,选出其中方差大的,作为分割线的方向,就是接下来将要创建点的split值。
2、将上面的数据点按照分割方向的维度进行排序,选出其中的中位数的点作为数据矢量,就是要分割的分割点。
3、同时进行空间矢量的再次划分,要在父亲节点的空间范围内再进行子分割,就是Range变量,不理解的话,可以阅读我的代码加以理解。
4、对剩余的节点进行左侧空间和右侧空间的分割,进行左孩子和右孩子节点的分割。
5、分割的终点是最终只剩下1个数据点或一侧没有数据点的情况。
在这里举个例子,给定6个数据点:
(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)
对这6个数据点进行最终的KD树的构建效果图如下,左边是实际分割效果,右边是所构成的KD树:
x,y代表的是当前节点的分割方向。读者可以进行手动计算并验证,本人不再加以描述。
KD树构建完毕,之后就是对于给定查询点数据,进行此空间数据的最近数据点,大致过程如下:
1、从根节点开始,从上往下,根据分割方向,在对应维度的坐标点上,进行树的顺序查找,比如给定(3,1),首先来到(7,2),因为根节点的划分方向为X,因此只比较X坐标的划分,因为3<7,所以往左边走,后续的节点同样的道理,最终到达叶子节点为止。
2、当然以这种方式找到的点并不一定是最近的,也许在父节点的另外一个空间内存在更近的点呢,或者说另外一种情况,当前的叶子节点的父亲节点比叶子节点离查询点更近呢,这也是有可能的。
3、所以这个过程会有回溯的步骤,回溯到父节点时候,需要做2点,第一要和父节点比,谁里查询点更近,如果父节点更近,则更改当前找到的最近点,第二以查询点为圆心,当前查询点与最近点的距离为半径画个圆,判断是否与父节点的分割线是否相交,如果相交,则说明有存在父节点另外的孩子空间存在于查询距离更短的点,然后进行父节点空间的又一次深度优先遍历。在局部的遍历查找完毕,在于当前的最近点做比较,比较完之后,继续往上回溯。
下面给出基于上面例子的2个测试例子,查询点为(2.1,3.1)和(2,4.5),前者的例子用于理解一般过程,后面的测试点真正诠释了递归,回溯的过程。先看下(2.1,3.1)的情况:
因为没有碰到任何的父节点分割边界,所以就一直回溯到根节点,最近的节点就是叶子节点(2,3).下面(2,4.5)是需要重点理解的例子,中间出现了一次回溯,和一次再搜索:
在第一次回溯的时候,发现与y=4碰撞到了,进行了又一次的搜寻,结果发现存在更近的点,因此结果变化了,具体的过程可以详细查看百度百科-kd树对这个例子的描述。
算法的代码实现
许多资料都是只有理论,没有实践,本人基于上面的测试例子,自己写了一个,效果还行,基本上实现了上述的过程,不过貌似Range这个变量没有表现出用途来,可以我一番设计,例子完全是上面的例子,输入数据就不放出来了,就是给定的6个坐标点。
坐标点类Point.java:
package DataMining_KDTree;
/**
* 坐标点类
*
* @author lyq
*
*/
public class Point{
// 坐标点横坐标
Double x;
// 坐标点纵坐标
Double y;
public Point(double x, double y){
this.x = x;
this.y = y;
}
public Point(String x, String y) {
this.x = (Double.parseDouble(x));
this.y = (Double.parseDouble(y));
}
/**
* 计算当前点与制定点之间的欧式距离
*
* @param p
* 待计算聚类的p点
* @return
*/
public double ouDistance(Point p) {
double distance = 0;
distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
* (this.y - p.y);
distance = Math.sqrt(distance);
return distance;
}
/**
* 判断2个坐标点是否为用个坐标点
*
* @param p
* 待比较坐标点
* @return
*/
public boolean isTheSame(Point p) {
boolean isSamed = false;
if (this.x == p.x && this.y == p.y) {
isSamed = true;
}
return isSamed;
}
}
空间矢量类Range.java:
package DataMining_KDTree;
/**
* 空间矢量,表示所代表的空间范围
*
* @author lyq
*
*/
public class Range {
// 边界左边界
double left;
// 边界右边界
double right;
// 边界上边界
double top;
// 边界下边界
double bottom;
public Range() {
this.left = -Integer.MAX_VALUE;
this.right = Integer.MAX_VALUE;
this.top = Integer.MAX_VALUE;
this.bottom = -Integer.MAX_VALUE;
}
public Range(int left, int right, int top, int bottom) {
this.left = left;
this.right = right;
this.top = top;
this.bottom = bottom;
}
/**
* 空间矢量进行并操作
*
* @param range
* @return
*/
public Range crossOperation(Range r) {
Range range = new Range();
// 取靠近右侧的左边界
if (r.left > this.left) {
range.left = r.left;
} else {
range.left = this.left;
}
// 取靠近左侧的右边界
if (r.right < this.right) {
range.right = r.right;
} else {
range.right = this.right;
}
// 取靠近下侧的上边界
if (r.top < this.top) {
range.top = r.top;
} else {
range.top = this.top;
}
// 取靠近上侧的下边界
if (r.bottom > this.bottom) {
range.bottom = r.bottom;
} else {
range.bottom = this.bottom;
}
return range;
}
/**
* 根据坐标点分割方向确定左侧空间矢量
*
* @param p
* 数据矢量
* @param dir
* 分割方向
* @return
*/
public static Range initLeftRange(Point p, int dir) {
Range range = new Range();
if (dir == KDTreeTool.DIRECTION_X) {
range.right = p.x;
} else {
range.bottom = p.y;
}
return range;
}
/**
* 根据坐标点分割方向确定右侧空间矢量
*
* @param p
* 数据矢量
* @param dir
* 分割方向
* @return
*/
public static Range initRightRange(Point p, int dir) {
Range range = new Range();
if (dir == KDTreeTool.DIRECTION_X) {
range.left = p.x;
} else {
range.top = p.y;
}
return range;
}
}
KD树节点类TreeNode.java:
package DataMining_KDTree;
/**
* KD树节点
* @author lyq
*
*/
public class TreeNode {
//数据矢量
Point nodeData;
//分割平面的分割线
int spilt;
//空间矢量,该节点所表示的空间范围
Range range;
//父节点
TreeNode parentNode;
//位于分割超平面左侧的孩子节点
TreeNode leftNode;
//位于分割超平面右侧的孩子节点
TreeNode rightNode;
//节点是否被访问过,用于回溯时使用
boolean isVisited;
public TreeNode(){
this.isVisited = false;
}
}
算法封装类KDTreeTool.java:
package DataMining_KDTree;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Stack;
/**
* KD树-k维空间关键数据检索算法工具类
*
* @author lyq
*
*/
public class KDTreeTool {
// 空间平面的方向
public static final int DIRECTION_X = 0;
public static final int DIRECTION_Y = 1;
// 输入的测试数据坐标点文件
private String filePath;
// 原始所有数据点数据
private ArrayList totalDatas;
// KD树根节点
private TreeNode rootNode;
public KDTreeTool(String filePath) {
this.filePath = filePath;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
Point p;
totalDatas = new ArrayList<>();
for (String[] array : dataArray) {
p = new Point(array[0], array[1]);
totalDatas.add(p);
}
}
/**
* 创建KD树
*
* @return
*/
public TreeNode createKDTree() {
ArrayList copyDatas;
rootNode = new TreeNode();
// 根据节点开始时所表示的空间时无限大的
rootNode.range = new Range();
copyDatas = (ArrayList) totalDatas.clone();
recusiveConstructNode(rootNode, copyDatas);
return rootNode;
}
/**
* 递归进行KD树的构造
*
* @param node
* 当前正在构造的节点
* @param datas
* 该节点对应的正在处理的数据
* @return
*/
private void recusiveConstructNode(TreeNode node, ArrayList datas) {
int direction = 0;
ArrayList leftSideDatas;
ArrayList rightSideDatas;
Point p;
TreeNode leftNode;
TreeNode rightNode;
Range range;
Range range2;
// 如果划分的数据点集合只有1个数据,则不再划分
if (datas.size() == 1) {
node.nodeData = datas.get(0);
return;
}
// 首先在当前的数据点集合中进行分割方向的选择
direction = selectSplitDrc(datas);
// 根据方向取出中位数点作为数据矢量
p = getMiddlePoint(datas, direction);
node.spilt = direction;
node.nodeData = p;
leftSideDatas = getLeftSideDatas(datas, p, direction);
datas.removeAll(leftSideDatas);
// 还要去掉自身
datas.remove(p);
rightSideDatas = datas;
if (leftSideDatas.size() > 0) {
leftNode = new TreeNode();
leftNode.parentNode = node;
range2 = Range.initLeftRange(p, direction);
// 获取父节点的空间矢量,进行交集运算做范围拆分
range = node.range.crossOperation(range2);
leftNode.range = range;
node.leftNode = leftNode;
recusiveConstructNode(leftNode, leftSideDatas);
}
if (rightSideDatas.size() > 0) {
rightNode = new TreeNode();
rightNode.parentNode = node;
range2 = Range.initRightRange(p, direction);
// 获取父节点的空间矢量,进行交集运算做范围拆分
range = node.range.crossOperation(range2);
rightNode.range = range;
node.rightNode = rightNode;
recusiveConstructNode(rightNode, rightSideDatas);
}
}
/**
* 搜索出给定数据点的最近点
*
* @param p
* 待比较坐标点
*/
public Point searchNearestData(Point p) {
// 节点距离给定数据点的距离
TreeNode nearestNode = null;
// 用栈记录遍历过的节点
Stack stackNodes;
stackNodes = new Stack<>();
findedNearestLeafNode(p, rootNode, stackNodes);
// 取出叶子节点,作为当前找到的最近节点
nearestNode = stackNodes.pop();
nearestNode = dfsSearchNodes(stackNodes, p, nearestNode);
return nearestNode.nodeData;
}
/**
* 深度优先的方式进行最近点的查找
*
* @param stack
* KD树节点栈
* @param desPoint
* 给定的数据点
* @param nearestNode
* 当前找到的最近节点
* @return
*/
private TreeNode dfsSearchNodes(Stack stack, Point desPoint,
TreeNode nearestNode) {
// 是否碰到父节点边界
boolean isCollision;
double minDis;
double dis;
TreeNode parentNode;
// 如果栈内节点已经全部弹出,则遍历结束
if (stack.isEmpty()) {
return nearestNode;
}
// 获取父节点
parentNode = stack.pop();
minDis = desPoint.ouDistance(nearestNode.nodeData);
dis = desPoint.ouDistance(parentNode.nodeData);
// 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新
if (dis < minDis) {
minDis = dis;
nearestNode = parentNode;
}
// 默认没有碰撞到
isCollision = false;
// 判断是否触碰到了父节点的空间分割线
if (parentNode.spilt == DIRECTION_X) {
if (parentNode.nodeData.x > desPoint.x - minDis
&& parentNode.nodeData.x < desPoint.x + minDis) {
isCollision = true;
}
} else {
if (parentNode.nodeData.y > desPoint.y - minDis
&& parentNode.nodeData.y < desPoint.y + minDis) {
isCollision = true;
}
}
// 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历
if (isCollision
&& (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) {
TreeNode newNode;
// 新建当前的小局部节点栈
Stack otherStack = new Stack<>();
// 从parentNode的树以下继续寻找
findedNearestLeafNode(desPoint, parentNode, otherStack);
newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop());
dis = newNode.nodeData.ouDistance(desPoint);
if (dis < minDis) {
nearestNode = newNode;
}
}
// 继续往上回溯
nearestNode = dfsSearchNodes(stack, desPoint, nearestNode);
return nearestNode;
}
/**
* 找到与所给定节点的最近的叶子节点
*
* @param p
* 待比较节点
* @param node
* 当前搜索到的节点
* @param stack
* 遍历过的节点栈
*/
private void findedNearestLeafNode(Point p, TreeNode node,
Stack stack) {
// 分割方向
int splitDic;
// 将遍历过的节点加入栈中
stack.push(node);
// 标记为访问过
node.isVisited = true;
// 如果此节点没有左右孩子节点说明已经是叶子节点了
if (node.leftNode == null && node.rightNode == null) {
return;
}
splitDic = node.spilt;
// 选择一个符合分割范围的节点继续递归搜寻
if ((splitDic == DIRECTION_X && p.x < node.nodeData.x)
|| (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) {
if (!node.leftNode.isVisited) {
findedNearestLeafNode(p, node.leftNode, stack);
} else {
// 如果左孩子节点已经访问过,则访问另一边
findedNearestLeafNode(p, node.rightNode, stack);
}
} else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x)
|| (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) {
if (!node.rightNode.isVisited) {
findedNearestLeafNode(p, node.rightNode, stack);
} else {
// 如果右孩子节点已经访问过,则访问另一边
findedNearestLeafNode(p, node.leftNode, stack);
}
}
}
/**
* 根据给定的数据点通过计算反差选择的分割点
*
* @param datas
* 部分的集合点集合
* @return
*/
private int selectSplitDrc(ArrayList datas) {
int direction = 0;
double avgX = 0;
double avgY = 0;
double varianceX = 0;
double varianceY = 0;
for (Point p : datas) {
avgX += p.x;
avgY += p.y;
}
avgX /= datas.size();
avgY /= datas.size();
for (Point p : datas) {
varianceX += (p.x - avgX) * (p.x - avgX);
varianceY += (p.y - avgY) * (p.y - avgY);
}
// 求最后的方差
varianceX /= datas.size();
varianceY /= datas.size();
// 通过比较方差的大小决定分割方向,选择波动较大的进行划分
direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y;
return direction;
}
/**
* 根据坐标点方位进行排序,选出中间点的坐标数据
*
* @param datas
* 数据点集合
* @param dir
* 排序的坐标方向
*/
private Point getMiddlePoint(ArrayList datas, int dir) {
int index = 0;
Point middlePoint;
index = datas.size() / 2;
if (dir == DIRECTION_X) {
Collections.sort(datas, new Comparator() {
@Override
public int compare(Point o1, Point o2) {
// TODO Auto-generated method stub
return o1.x.compareTo(o2.x);
}
});
} else {
Collections.sort(datas, new Comparator() {
@Override
public int compare(Point o1, Point o2) {
// TODO Auto-generated method stub
return o1.y.compareTo(o2.y);
}
});
}
// 取出中位数
middlePoint = datas.get(index);
return middlePoint;
}
/**
* 根据方向得到原部分节点集合左侧的数据点
*
* @param datas
* 原始数据点集合
* @param nodeData
* 数据矢量
* @param dir
* 分割方向
* @return
*/
private ArrayList getLeftSideDatas(ArrayList datas,
Point nodeData, int dir) {
ArrayList leftSideDatas = new ArrayList<>();
for (Point p : datas) {
if (dir == DIRECTION_X && p.x < nodeData.x) {
leftSideDatas.add(p);
} else if (dir == DIRECTION_Y && p.y < nodeData.y) {
leftSideDatas.add(p);
}
}
return leftSideDatas;
}
}
场景测试类Client.java:
package DataMining_KDTree;
import java.text.MessageFormat;
/**
* KD树算法测试类
*
* @author lyq
*
*/
public class Client {
public static void main(String[] args) {
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
Point queryNode;
Point searchedNode;
KDTreeTool tool = new KDTreeTool(filePath);
// 进行KD树的构建
tool.createKDTree();
// 通过KD树进行数据点的最近点查询
queryNode = new Point(2.1, 3.1);
searchedNode = tool.searchNearestData(queryNode);
System.out.println(MessageFormat.format(
"距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y,
searchedNode.x, searchedNode.y));
//重新构造KD树,去除之前的访问记录
tool.createKDTree();
queryNode = new Point(2, 4.5);
searchedNode = tool.searchNearestData(queryNode);
System.out.println(MessageFormat.format(
"距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y,
searchedNode.x, searchedNode.y));
}
}
算法的输出结果:
距离查询点(2.1, 3.1)最近的坐标点为(2, 3)
距离查询点(2, 4.5)最近的坐标点为(2, 3)
算法的输出结果与期望值还是一致的。
目前KD-Tree的使用场景是SIFT算法做特征点匹配的时候使用到了,特征点匹配指的是通过距离函数在高维矢量空间进行相似性检索。
参考文献:百度百科 http://baike.baidu.com
我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
作者:Androidlushangderen 发表于2015/4/10 21:39:58 原文链接
阅读:581 评论:0 查看评论
随机森林和GBDT的学习
2015年3月30日 20:28
参考文献:http://www.zilhua.com/629.html
http://www.tuicool.com/articles/JvMJve
http://blog.sina.com.cn/s/blog_573085f70101ivj5.html
我的数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
前言
提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的”树“,指的就是Decision Tree-----决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林=boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradient Boosting Decision Tree,就是梯度提升决策树,与随机森林的思想很像,但是比随机森林稍稍的难一点,当然效果相对于前者而言,也会好许多。由于本人才疏学浅,本文只会详细讲述Random Forest算法的部分,至于GBDT我会给出一小段篇幅做介绍引导,读者能够如果有兴趣的话,可以自行学习。
随机森林算法
决策树
要想理解随机森林算法,就不得不提决策树,什么是决策树,如何构造决策树,简单的回答就是数据的分类以树形结构的方式所展现,每个子分支都代表着不同的分类情况,比如下面的这个图所示:
当然决策树的每个节点分支不一定是三元的,可以有2个或者更多。分类的终止条件为,没有可以再拿来分类的属性条件或者说分到的数据的分类已经完全一致的情况。决策树分类的标准和依据是什么呢,下面介绍主要的2种划分标准。
1、信息增益。这是ID3算法系列所用的方法,C4.5算法在这上面做了少许的改进,用信息增益率来作为划分的标准,可以稍稍减小数据过于拟合的缺点。
2、基尼指数。这是CART分类回归树所用的方法。也是类似于信息增益的一个定义,最终都是根据数据划分后的纯度来做比较,这个纯度,你也可以理解为熵的变化,当然我们所希望的情况就是分类后数据的纯度更纯,也就是说,前后划分分类之后的熵的差越大越好。不过CART算法比较好的一点是树构造好后,还有剪枝的操作,剪枝操作的种类就比较多了,我之前在实现CART算法时用的是代价复杂度的剪枝方法。
这2种决策算法在我之前的博文中已经有所提及,不理解的可以点击我的ID3系列算法介绍和我的CART分类回归树算法。
Boosting
原本不打算将Boosting单独拉出来讲的,后来想想还是有很多内容可谈的。Boosting本身不是一种算法,他更应该说是一种思想,首先对数据构造n个弱分类器,最后通过组合n个弱分类器对于某个数据的判断结果作为最终的分类结果,就变成了一个强分类器,效果自然要好过单一分类器的分类效果。他可以理解为是一种提升算法,举一个比较常见的Boosting思想的算法AdaBoost,他在训练每个弱分类器的时候,提高了对于之前分错数据的权重值,最终能够组成一批相互互补的分类器集合。详细可以查看我的AdaBoost算法学习。
OK,2个重要的概念都已经介绍完毕,终于可以介绍主角Random Forest的出现了,正如前言中所说Random Forest=Decision Trees + Boosting,这里的每个弱分类器就是一个决策树了,不过这里的决策树都是二叉树,就是只有2个孩子分支,自然我立刻想到的做法就是用CART算法来构建,因为人家算法就是二元分支的。随机算法,随机算法,当然重在随机2个字上面,下面是2个方面体现了随机性。对于数据样本的采集量,比如我数据由100条,我可以每次随机取出其中的20条,作为我构造决策树的源数据,采取又放回的方式,并不是第一次抽到的数据,第二次不能重复,第二随机性体现在对于数据属性的随机采集,比如一行数据总共有10个特征属性,我每次随机采用其中的4个。正是由于对于数据的行压缩和列压缩,使得数据的随机性得以保证,就很难出现之前的数据过拟合的问题了,也就不需要在决策树最后进行剪枝操作了,这个是与一般的CART算法所不同的,尤其需要注意。
下面是随机森林算法的构造过程:
1、通过给定的原始数据,选出其中部分数据进行决策树的构造,数据选取是”有放回“的过程,我在这里用的是CART分类回归树。
2、随机森林构造完成之后,给定一组测试数据,使得每个分类器对其结果分类进行评估,最后取评估结果的众数最为最终结果。
算法非常的好理解,在Boosting算法和决策树之上做了一个集成,下面给出算法的实现,很多资料上只有大篇幅的理论,我还是希望能带给大家一点实在的东西。
随机算法的实现
输入数据(之前决策树算法时用过的)input.txt:
Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No
树节点类TreeNode.java:
package DataMining_RandomForest;
import java.util.ArrayList;
/**
* 回归分类树节点
*
* @author lyq
*
*/
public class TreeNode {
// 节点属性名字
private String attrName;
// 节点索引标号
private int nodeIndex;
//包含的叶子节点数
private int leafNum;
// 节点误差率
private double alpha;
// 父亲分类属性值
private String parentAttrValue;
// 孩子节点
private TreeNode[] childAttrNode;
// 数据记录索引
private ArrayList dataIndex;
public String getAttrName() {
return attrName;
}
public void setAttrName(String attrName) {
this.attrName = attrName;
}
public int getNodeIndex() {
return nodeIndex;
}
public void setNodeIndex(int nodeIndex) {
this.nodeIndex = nodeIndex;
}
public double getAlpha() {
return alpha;
}
public void setAlpha(double alpha) {
this.alpha = alpha;
}
public String getParentAttrValue() {
return parentAttrValue;
}
public void setParentAttrValue(String parentAttrValue) {
this.parentAttrValue = parentAttrValue;
}
public TreeNode[] getChildAttrNode() {
return childAttrNode;
}
public void setChildAttrNode(TreeNode[] childAttrNode) {
this.childAttrNode = childAttrNode;
}
public ArrayList getDataIndex() {
return dataIndex;
}
public void setDataIndex(ArrayList dataIndex) {
this.dataIndex = dataIndex;
}
public int getLeafNum() {
return leafNum;
}
public void setLeafNum(int leafNum) {
this.leafNum = leafNum;
}
}
决策树类DecisionTree.java:
package DataMining_RandomForest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
* 决策树
*
* @author lyq
*
*/
public class DecisionTree {
// 树的根节点
TreeNode rootNode;
// 数据的属性列名称
String[] featureNames;
// 这棵树所包含的数据
ArrayList datas;
// 决策树构造的的工具类
CARTTool tool;
public DecisionTree(ArrayList datas) {
this.datas = datas;
this.featureNames = datas.get(0);
tool = new CARTTool(datas);
// 通过CART工具类进行决策树的构建,并返回树的根节点
rootNode = tool.startBuildingTree();
}
/**
* 根据给定的数据特征描述进行类别的判断
*
* @param features
* @return
*/
public String decideClassType(String features) {
String classType = "";
// 查询属性组
String[] queryFeatures;
// 在本决策树中对应的查询的属性值描述
ArrayList featureStrs;
featureStrs = new ArrayList<>();
queryFeatures = features.split(",");
String[] array;
for (String name : featureNames) {
for (String featureValue : queryFeatures) {
array = featureValue.split("=");
// 将对应的属性值加入到列表中
if (array[0].equals(name)) {
featureStrs.add(array);
}
}
}
// 开始从根据节点往下递归搜索
classType = recusiveSearchClassType(rootNode, featureStrs);
return classType;
}
/**
* 递归搜索树,查询属性的分类类别
*
* @param node
* 当前搜索到的节点
* @param remainFeatures
* 剩余未判断的属性
* @return
*/
private String recusiveSearchClassType(TreeNode node,
ArrayList remainFeatures) {
String classType = null;
// 如果节点包含了数据的id索引,说明已经分类到底了
if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
classType = judgeClassType(node.getDataIndex());
return classType;
}
// 取出剩余属性中的一个匹配属性作为当前的判断属性名称
String[] currentFeature = null;
for (String[] featureValue : remainFeatures) {
if (node.getAttrName().equals(featureValue[0])) {
currentFeature = featureValue;
break;
}
}
for (TreeNode childNode : node.getChildAttrNode()) {
// 寻找子节点中属于此属性值的分支
if (childNode.getParentAttrValue().equals(currentFeature[1])) {
remainFeatures.remove(currentFeature);
classType = recusiveSearchClassType(childNode, remainFeatures);
// 如果找到了分类结果,则直接挑出循环
break;
}else{
//进行第二种情况的判断加上!符号的情况
String value = childNode.getParentAttrValue();
if(value.charAt(0) == '!'){
//去掉第一个!字符
value = value.substring(1, value.length());
if(!value.equals(currentFeature[1])){
remainFeatures.remove(currentFeature);
classType = recusiveSearchClassType(childNode, remainFeatures);
break;
}
}
}
}
return classType;
}
/**
* 根据得到的数据行分类进行类别的决策
*
* @param dataIndex
* 根据分类的数据索引号
* @return
*/
public String judgeClassType(ArrayList dataIndex) {
// 结果类型值
String resultClassType = "";
String classType = "";
int count = 0;
int temp = 0;
Map type2Num = new HashMap();
for (String index : dataIndex) {
temp = Integer.parseInt(index);
// 取最后一列的决策类别数据
classType = datas.get(temp)[featureNames.length - 1];
if (type2Num.containsKey(classType)) {
// 如果类别已经存在,则使其计数加1
count = type2Num.get(classType);
count++;
} else {
count = 1;
}
type2Num.put(classType, count);
}
// 选出其中类别支持计数最多的一个类别值
count = -1;
for (Map.Entry entry : type2Num.entrySet()) {
if ((int) entry.getValue() > count) {
count = (int) entry.getValue();
resultClassType = (String) entry.getKey();
}
}
return resultClassType;
}
}
随机森林算法工具类RandomForestTool.java:
package DataMining_RandomForest;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* 随机森林算法工具类
*
* @author lyq
*
*/
public class RandomForestTool {
// 测试数据文件地址
private String filePath;
// 决策树的样本占总数的占比率
private double sampleNumRatio;
// 样本数据的采集特征数量占总特征的比例
private double featureNumRatio;
// 决策树的采样样本数
private int sampleNum;
// 样本数据的采集采样特征数
private int featureNum;
// 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量
private int treeNum;
// 随机数产生器
private Random random;
// 样本数据列属性名称行
private String[] featureNames;
// 原始的总的数据
private ArrayList totalDatas;
// 决策树森林
private ArrayList decisionForest;
public RandomForestTool(String filePath, double sampleNumRatio,
double featureNumRatio) {
this.filePath = filePath;
this.sampleNumRatio = sampleNumRatio;
this.featureNumRatio = featureNumRatio;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
totalDatas = dataArray;
featureNames = totalDatas.get(0);
sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio);
//算属性数量的时候需要去掉id属性和决策属性,用条件属性计算
featureNum = (int) ((featureNames.length -2) * featureNumRatio);
// 算数量的时候需要去掉首行属性名称行
treeNum = (totalDatas.size() - 1) / sampleNum;
}
/**
* 产生决策树
*/
private DecisionTree produceDecisionTree() {
int temp = 0;
DecisionTree tree;
String[] tempData;
//采样数据的随机行号组
ArrayList sampleRandomNum;
//采样属性特征的随机列号组
ArrayList featureRandomNum;
ArrayList datas;
sampleRandomNum = new ArrayList<>();
featureRandomNum = new ArrayList<>();
datas = new ArrayList<>();
for(int i=0; i temp = random.nextInt(totalDatas.size());
//如果是行首属性名称行,则跳过
if(temp == 0){
continue;
}
if(!sampleRandomNum.contains(temp)){
sampleRandomNum.add(temp);
i++;
}
}
for(int i=0; i temp = random.nextInt(featureNames.length);
//如果是第一列的数据id号或者是决策属性列,则跳过
if(temp == 0 || temp == featureNames.length-1){
continue;
}
if(!featureRandomNum.contains(temp)){
featureRandomNum.add(temp);
i++;
}
}
String[] singleRecord;
String[] headCulumn = null;
// 获取随机数据行
for(int dataIndex: sampleRandomNum){
singleRecord = totalDatas.get(dataIndex);
//每行的列数=所选的特征数+id号
tempData = new String[featureNum+2];
headCulumn = new String[featureNum+2];
for(int i=0,k=1; i temp = featureRandomNum.get(i);
headCulumn[k] = featureNames[temp];
tempData[k] = singleRecord[temp];
}
//加上id列的信息
headCulumn[0] = featureNames[0];
//加上决策分类列的信息
headCulumn[featureNum+1] = featureNames[featureNames.length-1];
tempData[featureNum+1] = singleRecord[featureNames.length-1];
//加入此行数据
datas.add(tempData);
}
//加入行首列出现名称
datas.add(0, headCulumn);
//对筛选出的数据重新做id分配
temp = 0;
for(String[] array: datas){
//从第2行开始赋值
if(temp > 0){
array[0] = temp + "";
}
temp++;
}
tree = new DecisionTree(datas);
return tree;
}
/**
* 构造随机森林
*/
public void constructRandomTree() {
DecisionTree tree;
random = new Random();
decisionForest = new ArrayList<>();
System.out.println("下面是随机森林中的决策树:");
// 构造决策树加入森林中
for (int i = 0; i < treeNum; i++) {
System.out.println("\n决策树" + (i+1));
tree = produceDecisionTree();
decisionForest.add(tree);
}
}
/**
* 根据给定的属性条件进行类别的决策
*
* @param features
* 给定的已知的属性描述
* @return
*/
public String judgeClassType(String features) {
// 结果类型值
String resultClassType = "";
String classType = "";
int count = 0;
Map type2Num = new HashMap();
for (DecisionTree tree : decisionForest) {
classType = tree.decideClassType(features);
if (type2Num.containsKey(classType)) {
// 如果类别已经存在,则使其计数加1
count = type2Num.get(classType);
count++;
} else {
count = 1;
}
type2Num.put(classType, count);
}
// 选出其中类别支持计数最多的一个类别值
count = -1;
for (Map.Entry entry : type2Num.entrySet()) {
if ((int) entry.getValue() > count) {
count = (int) entry.getValue();
resultClassType = (String) entry.getKey();
}
}
return resultClassType;
}
}
CART算法工具类CARTTool.java:
package DataMining_RandomForest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;
/**
* CART分类回归树算法工具类
*
* @author lyq
*
*/
public class CARTTool {
// 类标号的值类型
private final String YES = "Yes";
private final String NO = "No";
// 所有属性的类型总数,在这里就是data源数据的列数
private int attrNum;
private String filePath;
// 初始源数据,用一个二维字符数组存放模仿表格数据
private String[][] data;
// 数据的属性行的名字
private String[] attrNames;
// 每个属性的值所有类型
private HashMap> attrValue;
public CARTTool(ArrayList dataArray) {
attrValue = new HashMap<>();
readData(dataArray);
}
/**
* 根据随机选取的样本数据进行初始化
* @param dataArray
* 已经读入的样本数据
*/
public void readData(ArrayList dataArray) {
data = new String[dataArray.size()][];
dataArray.toArray(data);
attrNum = data[0].length;
attrNames = data[0];
}
/**
* 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
*/
public void initAttrValue() {
ArrayList tempValues;
// 按照列的方式,从左往右找
for (int j = 1; j < attrNum; j++) {
// 从一列中的上往下开始寻找值
tempValues = new ArrayList<>();
for (int i = 1; i < data.length; i++) {
if (!tempValues.contains(data[i][j])) {
// 如果这个属性的值没有添加过,则添加
tempValues.add(data[i][j]);
}
}
// 一列属性的值已经遍历完毕,复制到map属性表中
attrValue.put(data[0][j], tempValues);
}
}
/**
* 计算机基尼指数
*
* @param remainData
* 剩余数据
* @param attrName
* 属性名称
* @param value
* 属性值
* @param beLongValue
* 分类是否属于此属性值
* @return
*/
public double computeGini(String[][] remainData, String attrName,
String value, boolean beLongValue) {
// 实例总数
int total = 0;
// 正实例数
int posNum = 0;
// 负实例数
int negNum = 0;
// 基尼指数
double gini = 0;
// 还是按列从左往右遍历属性
for (int j = 1; j < attrNames.length; j++) {
// 找到了指定的属性
if (attrName.equals(attrNames[j])) {
for (int i = 1; i < remainData.length; i++) {
// 统计正负实例按照属于和不属于值类型进行划分
if ((beLongValue && remainData[i][j].equals(value))
|| (!beLongValue && !remainData[i][j].equals(value))) {
if (remainData[i][attrNames.length - 1].equals(YES)) {
// 判断此行数据是否为正实例
posNum++;
} else {
negNum++;
}
}
}
}
}
total = posNum + negNum;
double posProbobly = (double) posNum / total;
double negProbobly = (double) negNum / total;
gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;
// 返回计算基尼指数
return gini;
}
/**
* 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
*
* @param remainData
* 剩余谁
* @param attrName
* 属性名称
* @return
*/
public String[] computeAttrGini(String[][] remainData, String attrName) {
String[] str = new String[2];
// 最终该属性的划分类型值
String spiltValue = "";
// 临时变量
int tempNum = 0;
// 保存属性的值划分时的最小的基尼指数
double minGini = Integer.MAX_VALUE;
ArrayList valueTypes = attrValue.get(attrName);
// 属于此属性值的实例数
HashMap belongNum = new HashMap<>();
for (String string : valueTypes) {
// 重新计数的时候,数字归0
tempNum = 0;
// 按列从左往右遍历属性
for (int j = 1; j < attrNames.length; j++) {
// 找到了指定的属性
if (attrName.equals(attrNames[j])) {
for (int i = 1; i < remainData.length; i++) {
// 统计正负实例按照属于和不属于值类型进行划分
if (remainData[i][j].equals(string)) {
tempNum++;
}
}
}
}
belongNum.put(string, tempNum);
}
double tempGini = 0;
double posProbably = 1.0;
double negProbably = 1.0;
for (String string : valueTypes) {
tempGini = 0;
posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
negProbably = 1 - posProbably;
tempGini += posProbably
* computeGini(remainData, attrName, string, true);
tempGini += negProbably
* computeGini(remainData, attrName, string, false);
if (tempGini < minGini) {
minGini = tempGini;
spiltValue = string;
}
}
str[0] = spiltValue;
str[1] = minGini + "";
return str;
}
public void buildDecisionTree(TreeNode node, String parentAttrValue,
String[][] remainData, ArrayList remainAttr,
boolean beLongParentValue) {
// 属性划分值
String valueType = "";
// 划分属性名称
String spiltAttrName = "";
double minGini = Integer.MAX_VALUE;
double tempGini = 0;
// 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
String[] giniArray;
if (beLongParentValue) {
node.setParentAttrValue(parentAttrValue);
} else {
node.setParentAttrValue("!" + parentAttrValue);
}
if (remainAttr.size() == 0) {
if (remainData.length > 1) {
ArrayList indexArray = new ArrayList<>();
for (int i = 1; i < remainData.length; i++) {
indexArray.add(remainData[i][0]);
}
node.setDataIndex(indexArray);
}
// System.out.println("attr remain null");
return;
}
for (String str : remainAttr) {
giniArray = computeAttrGini(remainData, str);
tempGini = Double.parseDouble(giniArray[1]);
if (tempGini < minGini) {
spiltAttrName = str;
minGini = tempGini;
valueType = giniArray[0];
}
}
// 移除划分属性
remainAttr.remove(spiltAttrName);
node.setAttrName(spiltAttrName);
// 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点
TreeNode[] childNode = new TreeNode[2];
String[][] rData;
boolean[] bArray = new boolean[] { true, false };
for (int i = 0; i < bArray.length; i++) {
// 二元划分属于属性值的划分
rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);
boolean sameClass = true;
ArrayList indexArray = new ArrayList<>();
for (int k = 1; k < rData.length; k++) {
indexArray.add(rData[k][0]);
// 判断是否为同一类的
if (!rData[k][attrNames.length - 1]
.equals(rData[1][attrNames.length - 1])) {
// 只要有1个不相等,就不是同类型的
sameClass = false;
break;
}
}
childNode[i] = new TreeNode();
if (!sameClass) {
// 创建新的对象属性,对象的同个引用会出错
ArrayList rAttr = new ArrayList<>();
for (String str : remainAttr) {
rAttr.add(str);
}
buildDecisionTree(childNode[i], valueType, rData, rAttr,
bArray[i]);
} else {
String pAtr = (bArray[i] ? valueType : "!" + valueType);
childNode[i].setParentAttrValue(pAtr);
childNode[i].setDataIndex(indexArray);
}
}
node.setChildAttrNode(childNode);
}
/**
* 属性划分完毕,进行数据的移除
*
* @param srcData
* 源数据
* @param attrName
* 划分的属性名称
* @param valueType
* 属性的值类型
* @parame beLongValue 分类是否属于此值类型
*/
private String[][] removeData(String[][] srcData, String attrName,
String valueType, boolean beLongValue) {
String[][] desDataArray;
ArrayList desData = new ArrayList<>();
// 待删除数据
ArrayList selectData = new ArrayList<>();
selectData.add(attrNames);
// 数组数据转化到列表中,方便移除
for (int i = 0; i < srcData.length; i++) {
desData.add(srcData[i]);
}
// 还是从左往右一列列的查找
for (int j = 1; j < attrNames.length; j++) {
if (attrNames[j].equals(attrName)) {
for (int i = 1; i < desData.size(); i++) {
if (desData.get(i)[j].equals(valueType)) {
// 如果匹配这个数据,则移除其他的数据
selectData.add(desData.get(i));
}
}
}
}
if (beLongValue) {
desDataArray = new String[selectData.size()][];
selectData.toArray(desDataArray);
} else {
// 属性名称行不移除
selectData.remove(attrNames);
// 如果是划分不属于此类型的数据时,进行移除
desData.removeAll(selectData);
desDataArray = new String[desData.size()][];
desData.toArray(desDataArray);
}
return desDataArray;
}
/**
* 构造分类回归树,并返回根节点
* @return
*/
public TreeNode startBuildingTree() {
initAttrValue();
ArrayList remainAttr = new ArrayList<>();
// 添加属性,除了最后一个类标号属性
for (int i = 1; i < attrNames.length - 1; i++) {
remainAttr.add(attrNames[i]);
}
TreeNode rootNode = new TreeNode();
buildDecisionTree(rootNode, "", data, remainAttr, false);
setIndexAndAlpah(rootNode, 0, false);
showDecisionTree(rootNode, 1);
return rootNode;
}
/**
* 显示决策树
*
* @param node
* 待显示的节点
* @param blankNum
* 行空格符,用于显示树型结构
*/
private void showDecisionTree(TreeNode node, int blankNum) {
System.out.println();
for (int i = 0; i < blankNum; i++) {
System.out.print(" ");
}
System.out.print("--");
// 显示分类的属性值
if (node.getParentAttrValue() != null
&& node.getParentAttrValue().length() > 0) {
System.out.print(node.getParentAttrValue());
} else {
System.out.print("--");
}
System.out.print("--");
if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
String i = node.getDataIndex().get(0);
System.out.print("【" + node.getNodeIndex() + "】类别:"
+ data[Integer.parseInt(i)][attrNames.length - 1]);
System.out.print("[");
for (String index : node.getDataIndex()) {
System.out.print(index + ", ");
}
System.out.print("]");
} else {
// 递归显示子节点
System.out.print("【" + node.getNodeIndex() + ":"
+ node.getAttrName() + "】");
if (node.getChildAttrNode() != null) {
for (TreeNode childNode : node.getChildAttrNode()) {
showDecisionTree(childNode, 2 * blankNum);
}
} else {
System.out.print("【 Child Null】");
}
}
}
/**
* 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝
*
* @param node
* 开始的时候传入的是根节点
* @param index
* 开始的索引号,从1开始
* @param ifCutNode
* 是否需要剪枝
*/
private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) {
TreeNode tempNode;
// 最小误差代价节点,即将被剪枝的节点
TreeNode minAlphaNode = null;
double minAlpah = Integer.MAX_VALUE;
Queue nodeQueue = new LinkedList();
nodeQueue.add(node);
while (nodeQueue.size() > 0) {
index++;
// 从队列头部获取首个节点
tempNode = nodeQueue.poll();
tempNode.setNodeIndex(index);
if (tempNode.getChildAttrNode() != null) {
for (TreeNode childNode : tempNode.getChildAttrNode()) {
nodeQueue.add(childNode);
}
computeAlpha(tempNode);
if (tempNode.getAlpha() < minAlpah) {
minAlphaNode = tempNode;
minAlpah = tempNode.getAlpha();
} else if (tempNode.getAlpha() == minAlpah) {
// 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点
if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
minAlphaNode = tempNode;
}
}
}
}
if (ifCutNode) {
// 进行树的剪枝,让其左右孩子节点为null
minAlphaNode.setChildAttrNode(null);
}
}
/**
* 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
*
* @param node
* 待计算的非叶子节点
*/
private void computeAlpha(TreeNode node) {
double rt = 0;
double Rt = 0;
double alpha = 0;
// 当前节点的数据总数
int sumNum = 0;
// 最少的偏差数
int minNum = 0;
ArrayList dataIndex;
ArrayList leafNodes = new ArrayList<>();
addLeafNode(node, leafNodes);
node.setLeafNum(leafNodes.size());
for (TreeNode attrNode : leafNodes) {
dataIndex = attrNode.getDataIndex();
int num = 0;
sumNum += dataIndex.size();
for (String s : dataIndex) {
// 统计分类数据中的正负实例数
if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
num++;
}
}
minNum += num;
// 取小数量的值部分
if (1.0 * num / dataIndex.size() > 0.5) {
num = dataIndex.size() - num;
}
rt += (1.0 * num / (data.length - 1));
}
//同样取出少偏差的那部分
if (1.0 * minNum / sumNum > 0.5) {
minNum = sumNum - minNum;
}
Rt = 1.0 * minNum / (data.length - 1);
alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
node.setAlpha(alpha);
}
/**
* 筛选出节点所包含的叶子节点数
*
* @param node
* 待筛选节点
* @param leafNode
* 叶子节点列表容器
*/
private void addLeafNode(TreeNode node, ArrayList leafNode) {
ArrayList dataIndex;
if (node.getChildAttrNode() != null) {
for (TreeNode childNode : node.getChildAttrNode()) {
dataIndex = childNode.getDataIndex();
if (dataIndex != null && dataIndex.size() > 0) {
// 说明此节点为叶子节点
leafNode.add(childNode);
} else {
// 如果还是非叶子节点则继续递归调用
addLeafNode(childNode, leafNode);
}
}
}
}
}
测试类Client.java:
package DataMining_RandomForest;
import java.text.MessageFormat;
/**
* 随机森林算法测试场景
*
* @author lyq
*
*/
public class Client {
public static void main(String[] args) {
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair";
String resultClassType = "";
// 决策树的样本占总数的占比率
double sampleNumRatio = 0.4;
// 样本数据的采集特征数量占总特征的比例
double featureNumRatio = 0.5;
RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio,
featureNumRatio);
tool.constructRandomTree();
resultClassType = tool.judgeClassType(queryStr);
System.out.println();
System.out
.println(MessageFormat.format(
"查询属性描述{0},预测的分类结果为BuysCompute:{1}", queryStr,
resultClassType));
}
}
算法的输出
下面是随机森林中的决策树:
决策树1
--!--【1:Income】
--Medium--【2】类别:Yes[1, 2, ]
--!Medium--【3:Student】
--No--【4】类别:No[3, 5, ]
--!No--【5】类别:Yes[4, ]
决策树2
--!--【1:Student】
--No--【2】类别:No[1, 3, ]
--!No--【3】类别:Yes[2, 4, 5, ]
查询属性描述Age=Youth,Income=Low,Student=No,CreditRating=Fair,预测的分类结果为BuysCompute:No
输出的结果决策树建议从左往右看,从上往下,【】符号表示一个节点,---XX---表示属性值的划分,你就应该能看懂这棵树了,在console上想展示漂亮的树形效果的确很难。。。这里说一个算法的重大不足,数据太少,导致选择的样本数据不足,所选属性太少,,构造的决策树数量过少,自然分类的准确率不见得会有多准,博友只要能领会代码中所表达的算法的思想即可。
GBDT
下面来说说随机森林的兄弟算法GBDT,梯度提升决策树,他有很多的决策树,他也有组合的思想,但是他不是随机森林算法2,GBDT的关键在于Gradient Boosting,梯度提升。这个词语理解起来就不容易了。学术的描述,每一次建立模型是在之前建立模型的损失函数的梯度下降方向。GBDT的核心在于,每一棵树学的是之前所有树结论和的残差,这个残差你可以理解为与预测值的差值。举个例子:比如预测张三的年龄,张三的真实年龄18岁,第一棵树预测张的年龄12岁,此时残差为18-12=6岁,因此在第二棵树中,我们把张的年龄作为6岁去学习,如果预测成功了,则张的真实年龄就是A树和B树的结果预测值的和,但是如果B预测成了5岁,那么残差就变成了6-5=1岁,那么此时需要构建第三树对1岁做预测,后面一样的道理。每棵树都是对之前失败预测的一个补充,用公式的表达就是如下的这个样子:
F0在这里是初始值,Ti是一棵棵的决策树,不同的问题选择不同的损失函数和初始值。在阿里内部对于此算法的叫法为TreeLink。所以下次听到什么Treelink算法了指的就是梯度提升树算法,其实我在这里省略了很大篇幅的数学推导过程,再加上自己还不是专家,无法彻底解释清数学的部分,所以就没有提及,希望以后有时间可以深入学习此方面的知识。
作者:Androidlushangderen 发表于2015/3/30 20:28:53 原文链接
阅读:1064 评论:0 查看评论
遗传算法在走迷宫游戏中的应用
2015年3月26日 21:56
我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
前言
遗传(GA)算法是一个非常有意思的算法,因为他利用了生物进化理论的知识进行问题的求解。算法的核心就是把拥有更好环境适应度的基因遗传给下一代,这就是其中的关键的选择操作,遗传算法整体的阶段分为选择,交叉和变异操作,选择操作和变异操作在其中又是比较重要的步骤。本篇文章不会讲述GA算法的具体细节,之前我曾经写过一篇专门的文章介绍过此算法,链接:http://blog.csdn.net/androidlushangderen/article/details/44041499,里面介绍了一些基本的概念和算法的原理过程,如果你对GA算法掌握的还不错的话,那么对于理解后面遗传算法在走迷宫的应用来说应该不是难事。
算法在迷宫游戏中的应用
先说说走迷宫游戏要解决的问题是什么, 走迷宫游戏说白了就是给定起点,终点,中间设置一堆的障碍,然后要求可达的路径,注意这里指的是可达路径,并没有说一定是最优路径,因为最优路径一定是用步数最少的,这一点还是很不同的。而另一方面,遗传算法也是用来搜索问题最优解的,所以刚刚好可以转移到这个问题上。用一个遗传算法去解决生活中的实际问题最关键的就是如何用遗传算法中的概念表示出来,比如遗传算法中核心的几个概念,基因编码,基因长度的设置,适应度函数的定义,3个概念每个都很重要。好的,目的要求已经慢慢的明确了,下面一个个问题的解决。
为了能让大家更好的理解,下面举出一个例子,如图所示:
图是自己做的,比较简略,以左边点的形式表示,从图中可以看出,起点位置(4, 4),出口左边为绿色区域位置(1,0),X符号表示的障碍区域,不允许经过,问题就转为搜索出从起点到终点位置的最短路径,因为本身例子构造的不是很复杂,我们按照对角线的方式出发,总共的步数=4-1 + 4-0=7步,只要中间不拐弯,每一步都是靠近目标点方向的移动就是最佳的方式。下面看看如何转化成遗传算法中的概念表示。
个体基因长度
首先是基于长度,因为最后筛选出的是一个个体,就是满足条件的个体,他的基因编码就是问题的最优解,所以就能联想把角色的每一步移动操作看出是一个基因编码,总共7步就需要7个基因值表示,所以基因的长度在本例子中就是7。
基因表示
已经将角色的每一次的移动步骤转化为基因的表示,每次的移动总共有4种可能,上下左右,基因编码是标准的二进制形式,所以可以取值为00代表向上,01向下,10向左,11向右,也就是说,每个基因组用2个编码表示,所以总共的编码数字就是2*7=14个,两两一对。
适应度函数
适应度函数的设置应该是在遗传算法中最重要了吧,以为他的设置好坏直接决定着遗传质量的好坏,基因组表示的移动的操作步骤,给定起点位置,通过基因组的编码组数据,我们可以计算出最终的抵达坐标,这里可以很容易的得出结论,如果最后的抵达坐标越接近出口坐标,就越是我们想要的结果,也就是适应值越高,所以我们可以用下面的公式作为适应度函数:
(x, y)为计算出的适应值的函数值在0到1之间波动,1为最大值,就是抵达的坐标恰好是出口位置的时候,当然适应度函数的表示不是唯一的。
算法的代码实现
算法地图数据的输入mapData.txt:
0 0 0 0 0
2 0 0 -1 0
0 0 0 0 0
0 -1 0 0 -1
0 0 0 0 1
就是上面图示的那个例子.
算法的主要实现类GATool.java:
package GA_Maze;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Random;
/**
* 遗传算法在走迷宫游戏的应用-遗传算法工具类
*
* @author lyq
*
*/
public class GATool {
// 迷宫出入口标记
public static final int MAZE_ENTRANCE_POS = 1;
public static final int MAZE_EXIT_POS = 2;
// 方向对应的编码数组
public static final int[][] MAZE_DIRECTION_CODE = new int[][] { { 0, 0 },
{ 0, 1 }, { 1, 0 }, { 1, 1 }, };
// 坐标点方向改变
public static final int[][] MAZE_DIRECTION_CHANGE = new int[][] {
{ -1, 0 }, { 1, 0 }, { 0, -1 }, { 0, 1 }, };
// 方向的文字描述
public static final String[] MAZE_DIRECTION_LABEL = new String[] { "上",
"下", "左", "右" };
// 地图数据文件地址
private String filePath;
// 走迷宫的最短步数
private int stepNum;
// 初始个体的数量
private int initSetsNum;
// 迷宫入口位置
private int[] startPos;
// 迷宫出口位置
private int[] endPos;
// 迷宫地图数据
private int[][] mazeData;
// 初始个体集
private ArrayList initSets;
// 随机数产生器
private Random random;
public GATool(String filePath, int initSetsNum) {
this.filePath = filePath;
this.initSetsNum = initSetsNum;
readDataFile();
}
/**
* 从文件中读取数据
*/
public void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
int rowNum = dataArray.size();
mazeData = new int[rowNum][rowNum];
for (int i = 0; i < rowNum; i++) {
String[] data = dataArray.get(i);
for (int j = 0; j < data.length; j++) {
mazeData[i][j] = Integer.parseInt(data[j]);
// 赋值入口和出口位置
if (mazeData[i][j] == MAZE_ENTRANCE_POS) {
startPos = new int[2];
startPos[0] = i;
startPos[1] = j;
} else if (mazeData[i][j] == MAZE_EXIT_POS) {
endPos = new int[2];
endPos[0] = i;
endPos[1] = j;
}
}
}
// 计算走出迷宫的最短步数
stepNum = Math.abs(startPos[0] - endPos[0])
+ Math.abs(startPos[1] - endPos[1]);
}
/**
* 产生初始数据集
*/
private void produceInitSet() {
// 方向编码
int directionCode = 0;
random = new Random();
initSets = new ArrayList<>();
// 每个步骤的操作需要用2位数字表示
int[] codeNum;
for (int i = 0; i < initSetsNum; i++) {
codeNum = new int[stepNum * 2];
for (int j = 0; j < stepNum; j++) {
directionCode = random.nextInt(4);
codeNum[2 * j] = MAZE_DIRECTION_CODE[directionCode][0];
codeNum[2 * j + 1] = MAZE_DIRECTION_CODE[directionCode][1];
}
initSets.add(codeNum);
}
}
/**
* 选择操作,把适值较高的个体优先遗传到下一代
*
* @param initCodes
* 初始个体编码
* @return
*/
private ArrayList selectOperate(ArrayList initCodes) {
double randomNum = 0;
double sumFitness = 0;
ArrayList resultCodes = new ArrayList<>();
double[] adaptiveValue = new double[initSetsNum];
for (int i = 0; i < initSetsNum; i++) {
adaptiveValue[i] = calFitness(initCodes.get(i));
sumFitness += adaptiveValue[i];
}
// 转成概率的形式,做归一化操作
for (int i = 0; i < initSetsNum; i++) {
adaptiveValue[i] = adaptiveValue[i] / sumFitness;
}
for (int i = 0; i < initSetsNum; i++) {
randomNum = random.nextInt(100) + 1;
randomNum = randomNum / 100;
//因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断
if(randomNum == 1){
randomNum = randomNum - 0.01;
}
sumFitness = 0;
// 确定区间
for (int j = 0; j < initSetsNum; j++) {
if (randomNum > sumFitness
&& randomNum <= sumFitness + adaptiveValue[j]) {
// 采用拷贝的方式避免引用重复
resultCodes.add(initCodes.get(j).clone());
break;
} else {
sumFitness += adaptiveValue[j];
}
}
}
return resultCodes;
}
/**
* 交叉运算
*
* @param selectedCodes
* 上步骤的选择后的编码
* @return
*/
private ArrayList crossOperate(ArrayList selectedCodes) {
int randomNum = 0;
// 交叉点
int crossPoint = 0;
ArrayList resultCodes = new ArrayList<>();
// 随机编码队列,进行随机交叉配对
ArrayList randomCodeSeqs = new ArrayList<>();
// 进行随机排序
while (selectedCodes.size() > 0) {
randomNum = random.nextInt(selectedCodes.size());
randomCodeSeqs.add(selectedCodes.get(randomNum));
selectedCodes.remove(randomNum);
}
int temp = 0;
int[] array1;
int[] array2;
// 进行两两交叉运算
for (int i = 1; i < randomCodeSeqs.size(); i++) {
if (i % 2 == 1) {
array1 = randomCodeSeqs.get(i - 1);
array2 = randomCodeSeqs.get(i);
crossPoint = random.nextInt(stepNum - 1) + 1;
// 进行交叉点位置后的编码调换
for (int j = 0; j < 2 * stepNum; j++) {
if (j >= 2 * crossPoint) {
temp = array1[j];
array1[j] = array2[j];
array2[j] = temp;
}
}
// 加入到交叉运算结果中
resultCodes.add(array1);
resultCodes.add(array2);
}
}
return resultCodes;
}
/**
* 变异操作
*
* @param crossCodes
* 交叉运算后的结果
* @return
*/
private ArrayList variationOperate(ArrayList crossCodes) {
// 变异点
int variationPoint = 0;
ArrayList resultCodes = new ArrayList<>();
for (int[] array : crossCodes) {
variationPoint = random.nextInt(stepNum);
for (int i = 0; i < array.length; i += 2) {
// 变异点进行变异
if (i % 2 == 0 && i / 2 == variationPoint) {
array[i] = (array[i] == 0 ? 1 : 0);
array[i + 1] = (array[i + 1] == 0 ? 1 : 0);
break;
}
}
resultCodes.add(array);
}
return resultCodes;
}
/**
* 根据编码计算适值
*
* @param code
* 当前的编码
* @return
*/
public double calFitness(int[] code) {
double fintness = 0;
// 由编码计算所得的终点横坐标
int endX = 0;
// 由编码计算所得的终点纵坐标
int endY = 0;
// 基于片段所代表的行走方向
int direction = 0;
// 临时坐标点横坐标
int tempX = 0;
// 临时坐标点纵坐标
int tempY = 0;
endX = startPos[0];
endY = startPos[1];
for (int i = 0; i < stepNum; i++) {
direction = binaryArrayToNum(new int[] { code[2 * i],
code[2 * i + 1] });
// 根据方向改变数组做坐标点的改变
tempX = endX + MAZE_DIRECTION_CHANGE[direction][0];
tempY = endY + MAZE_DIRECTION_CHANGE[direction][1];
// 判断坐标点是否越界
if (tempX >= 0 && tempX < mazeData.length && tempY >= 0
&& tempY < mazeData[0].length) {
// 判断坐标点是否走到阻碍块
if (mazeData[tempX][tempY] != -1) {
endX = tempX;
endY = tempY;
}
}
}
// 根据适值函数进行适值的计算
fintness = 1.0 / (Math.abs(endX - endPos[0])
+ Math.abs(endY - endPos[1]) + 1);
return fintness;
}
/**
* 根据当前编码判断是否已经找到出口位置
*
* @param code
* 经过若干次遗传的编码
* @return
*/
private boolean ifArriveEndPos(int[] code) {
boolean isArrived = false;
// 由编码计算所得的终点横坐标
int endX = 0;
// 由编码计算所得的终点纵坐标
int endY = 0;
// 基于片段所代表的行走方向
int direction = 0;
// 临时坐标点横坐标
int tempX = 0;
// 临时坐标点纵坐标
int tempY = 0;
endX = startPos[0];
endY = startPos[1];
for (int i = 0; i < stepNum; i++) {
direction = binaryArrayToNum(new int[] { code[2 * i],
code[2 * i + 1] });
// 根据方向改变数组做坐标点的改变
tempX = endX + MAZE_DIRECTION_CHANGE[direction][0];
tempY = endY + MAZE_DIRECTION_CHANGE[direction][1];
// 判断坐标点是否越界
if (tempX >= 0 && tempX < mazeData.length && tempY >= 0
&& tempY < mazeData[0].length) {
// 判断坐标点是否走到阻碍块
if (mazeData[tempX][tempY] != -1) {
endX = tempX;
endY = tempY;
}
}
}
if (endX == endPos[0] && endY == endPos[1]) {
isArrived = true;
}
return isArrived;
}
/**
* 二进制数组转化为数字
*
* @param binaryArray
* 待转化二进制数组
*/
private int binaryArrayToNum(int[] binaryArray) {
int result = 0;
for (int i = binaryArray.length - 1, k = 0; i >= 0; i--, k++) {
if (binaryArray[i] == 1) {
result += Math.pow(2, k);
}
}
return result;
}
/**
* 进行遗传算法走出迷宫
*/
public void goOutMaze() {
// 迭代遗传次数
int loopCount = 0;
boolean canExit = false;
// 结果路径
int[] resultCode = null;
ArrayList initCodes;
ArrayList selectedCodes;
ArrayList crossedCodes;
ArrayList variationCodes;
// 产生初始数据集
produceInitSet();
initCodes = initSets;
while (true) {
for (int[] array : initCodes) {
// 遗传迭代的终止条件为是否找到出口位置
if (ifArriveEndPos(array)) {
resultCode = array;
canExit = true;
break;
}
}
if (canExit) {
break;
}
selectedCodes = selectOperate(initCodes);
crossedCodes = crossOperate(selectedCodes);
variationCodes = variationOperate(crossedCodes);
initCodes = variationCodes;
loopCount++;
//如果遗传次数超过100次,则退出
if(loopCount >= 100){
break;
}
}
System.out.println("总共遗传进化了" + loopCount + "次");
printFindedRoute(resultCode);
}
/**
* 输出找到的路径
*
* @param code
*/
private void printFindedRoute(int[] code) {
if(code == null){
System.out.println("在有限的遗传进化次数内,没有找到最优路径");
return;
}
int tempX = startPos[0];
int tempY = startPos[1];
int direction = 0;
System.out.println(MessageFormat.format(
"起始点位置({0},{1}), 出口点位置({2}, {3})", tempX, tempY, endPos[0],
endPos[1]));
System.out.print("搜索到的结果编码:");
for(int value: code){
System.out.print("" + value);
}
System.out.println();
for (int i = 0, k = 1; i < code.length; i += 2, k++) {
direction = binaryArrayToNum(new int[] { code[i], code[i + 1] });
tempX += MAZE_DIRECTION_CHANGE[direction][0];
tempY += MAZE_DIRECTION_CHANGE[direction][1];
System.out.println(MessageFormat.format(
"第{0}步,编码为{1}{2},向{3}移动,移动后到达({4},{5})", k, code[i], code[i+1],
MAZE_DIRECTION_LABEL[direction], tempX, tempY));
}
}
}
算法的调用类Client.java:
package GA_Maze;
/**
* 遗传算法在走迷宫游戏的应用
* @author lyq
*
*/
public class Client {
public static void main(String[] args) {
//迷宫地图文件数据地址
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\mapData.txt";
//初始个体数量
int initSetsNum = 4;
GATool tool = new GATool(filePath, initSetsNum);
tool.goOutMaze();
}
}
算法的输出:
我测了很多次的数据,因为有可能会一时半会搜索不出来,我设置了最大遗传次数100次。
总共遗传进化了2次
起始点位置(4,4), 出口点位置(1, 0)
搜索到的结果编码:10100000100010
第1步,编码为10,向左移动,移动后到达(4,3)
第2步,编码为10,向左移动,移动后到达(4,2)
第3步,编码为00,向上移动,移动后到达(3,2)
第4步,编码为00,向上移动,移动后到达(2,2)
第5步,编码为10,向左移动,移动后到达(2,1)
第6步,编码为00,向上移动,移动后到达(1,1)
第7步,编码为10,向左移动,移动后到达(1,0)
总共遗传进化了8次
起始点位置(4,4), 出口点位置(1, 0)
搜索到的结果编码:10001000101000
第1步,编码为10,向左移动,移动后到达(4,3)
第2步,编码为00,向上移动,移动后到达(3,3)
第3步,编码为10,向左移动,移动后到达(3,2)
第4步,编码为00,向上移动,移动后到达(2,2)
第5步,编码为10,向左移动,移动后到达(2,1)
第6步,编码为10,向左移动,移动后到达(2,0)
第7步,编码为00,向上移动,移动后到达(1,0)
总共遗传进化了100次
在有限的遗传进化次数内,没有找到最优路径
算法小结
遗传算法在走迷宫中的应用总体而言还是非常有意思的如果你去认真的体会的话,至少让我更加深入的理解了GA算法,如果博友向要亲自实现这算法,我给几点建议,第一是迷宫难度的和初始个体数量的设置,为什么要注意这2点呢,一个是这关系到遗传迭代的次数,在一段时间内有的时候遗传算法是找不出来的,如果找不出来,PC机的CPU会持续高速的计算,所以不要让遗传进行无限制的进行,最好做点次数限制,也可能是我的本本配置太烂了。。在算法的调试中修复了一个之前没发现的bug,就是选择阶段的时候对于随机数的判断少考虑了一种情形,当随机数取到1.0的时候,其实是不能判断到的,因为概念和只会无限接近1,就不知道被划分到哪个区域中了。
作者:Androidlushangderen 发表于2015/3/26 21:56:15 原文链接
阅读:1128 评论:0 查看评论
Chameleon两阶段聚类算法
2015年3月23日 20:43
参考文献:http://www.cnblogs.com/zhangchaoyang/articles/2182752.html(用了很多的图和思想)
博客园(华夏35度) 作者:Orisun
数据挖掘算法-Chameleon算法.百度文库
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib(里面可能有你正想要的算法)
算法介绍
本篇文章讲述的还是聚类算法,也是属于层次聚类算法领域的,不过与上篇文章讲述的分裂实现聚类的方式不同,这次所讲的Chameleon算法是合并形成最终的聚类,恰巧相反。Chamelon的英文单词的意思是变色龙,所以这个算法又称之为变色龙算法,变色龙算法的过程如标题所描绘的那样,是分为2个主要阶段的,不过他可不是像BIRCH算法那样,是树的形式。继续看下面的原理介绍。
算法原理
先来张图来大致了解整个算法的过程。
上面图的显示过程虽然说有3个阶段,但是这其中概况起来就是两个阶段,第一个是形成小簇集的过程就是从Data Set 到k最近邻图到分裂成小聚餐,第二个阶段是合并这些小聚簇形成最终的结果聚簇。理解了算法的大致过程,下面看看里面定义的一些概念,还不少的样子。
为了引出变色龙算法的一些定义,这里先说一下以往的一些聚类算法的不足之处。
1、忽略簇与簇之间的互连性。就会导致最终的结果形成如下:
2、忽略簇与簇之间的近似性。就会导致最终的聚类结果变成这样“:
为什么提这些呢,因为Chameleon算法正好弥补了这2点要求,兼具互连性和近似性。在Chameleon算法中定义了相对互连性,RI表示和相对近似性,RC表示,最后通过一个度量函数:
function value = RI( Ci, Cj)× RC( Ci, Cj)α,α在这里表示的多少次方的意思,不是乘法。
来作为2个簇是否能够合并的标准,其实这些都是第二阶段做的事情了。
在第一阶段,所做的一件关键的事情就是形成小簇集,由零星的几个数据点连成小簇,官方的作法是用hMetic算法根据最小化截断的边的权重和来分割k-最近邻图,然后我网上找了一些资料,没有确切的hMetic算法,借鉴了网上其他人的一些办法,于是用了一个很简单的思路,就是给定一个点,把他离他最近的k个点连接起来,就算是最小簇了。事实证明,效果也不会太差,最近的点的换一个意思就是与其最大权重的边,采用距离的倒数最为权重的大小。因为后面的计算,用到的会是权重而不是距离。
我们再回过头来细说第二阶段所做的事情,首先是2个略复杂的公式(直接采用截图的方式):
相对互连性RI=
相对近似性RC=
Ci,Cj表示的是i,j聚簇内的数据点的个数,EC(Ci)表示的Ci聚簇内的边的权重和,EC(Ci,Cj)表示的是连接2个聚簇的边的权重和。
后来我在查阅书籍和一些文库的时候发现,这个公式还不是那么的标准,因为他对分母,分子进行了部分的改变,但是大意上还是一致的,标准公式上用到的是平均权重,而这里用的是和的形式,差别不大,所以就用这个公式了。
那么合并的过程如下:
1、给定度量函数如下minMetric,
2、访问每个簇,计算他与邻近的每个簇的RC和RI,通过度量函数公式计算出值tempMetric。
3、找到最大的tempMetric,如果最大的tempMetric超过阈值minMetric,将簇与此值对应的簇合并
4、如果找到的最大的tempMetric没有超过阈值,则表明此聚簇已合并完成,移除聚簇列表,加入到结果聚簇中。
4、递归步骤2,直到待合并聚簇列表最终大小为空。
算法的实现
算法的输入依旧采用的是坐标点的形式graphData.txt:
0 2 2
1 3 1
2 3 4
3 3 14
4 5 3
5 8 3
6 8 6
7 9 8
8 10 4
9 10 7
10 10 10
11 10 14
12 11 13
13 12 8
14 12 15
15 14 7
16 14 9
17 14 15
18 15 8
算法坐标点数据Point.java:
package DataMining_Chameleon;
/**
* 坐标点类
* @author lyq
*
*/
public class Point{
//坐标点id号,id号唯一
int id;
//坐标横坐标
Integer x;
//坐标纵坐标
Integer y;
//是否已经被访问过
boolean isVisited;
public Point(String id, String x, String y){
this.id = Integer.parseInt(id);
this.x = Integer.parseInt(x);
this.y = Integer.parseInt(y);
}
/**
* 计算当前点与制定点之间的欧式距离
*
* @param p
* 待计算聚类的p点
* @return
*/
public double ouDistance(Point p) {
double distance = 0;
distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
* (this.y - p.y);
distance = Math.sqrt(distance);
return distance;
}
/**
* 判断2个坐标点是否为用个坐标点
*
* @param p
* 待比较坐标点
* @return
*/
public boolean isTheSame(Point p) {
boolean isSamed = false;
if (this.x == p.x && this.y == p.y) {
isSamed = true;
}
return isSamed;
}
}
簇类Cluster.java:
package DataMining_Chameleon;
import java.util.ArrayList;
/**
* 聚簇类
*
* @author lyq
*
*/
public class Cluster implements Cloneable{
//簇唯一id标识号
int id;
// 聚簇内的坐标点集合
ArrayList points;
// 聚簇内的所有边的权重和
double weightSum = 0;
public Cluster(int id, ArrayList points) {
this.id = id;
this.points = points;
}
/**
* 计算聚簇的内部的边权重和
*
* @return
*/
public double calEC() {
int id1 = 0;
int id2 = 0;
weightSum = 0;
for (Point p1 : points) {
for (Point p2 : points) {
id1 = p1.id;
id2 = p2.id;
// 为了避免重复计算,取id1小的对应大的
if (id1 < id2 && ChameleonTool.edges[id1][id2] == 1) {
weightSum += ChameleonTool.weights[id1][id2];
}
}
}
return weightSum;
}
/**
* 计算2个簇之间最近的n条边
*
* @param otherCluster
* 待比较的簇
* @param n
* 最近的边的数目
* @return
*/
public ArrayList calNearestEdge(Cluster otherCluster, int n){
int count = 0;
double distance = 0;
double minDistance = Integer.MAX_VALUE;
Point point1 = null;
Point point2 = null;
ArrayList edgeList = new ArrayList<>();
ArrayList pointList1 = (ArrayList) points.clone();
ArrayList pointList2 = null;
Cluster c2 = null;
try {
c2 = (Cluster) otherCluster.clone();
pointList2 = c2.points;
} catch (CloneNotSupportedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
int[] tempEdge;
// 循环计算出每次的最近距离
while (count < n) {
tempEdge = new int[2];
minDistance = Integer.MAX_VALUE;
for (Point p1 : pointList1) {
for (Point p2 : pointList2) {
distance = p1.ouDistance(p2);
if (distance < minDistance) {
point1 = p1;
point2 = p2;
tempEdge[0] = p1.id;
tempEdge[1] = p2.id;
minDistance = distance;
}
}
}
pointList1.remove(point1);
pointList2.remove(point2);
edgeList.add(tempEdge);
count++;
}
return edgeList;
}
@Override
protected Object clone() throws CloneNotSupportedException {
// TODO Auto-generated method stub
//引用需要再次复制,实现深拷贝
ArrayList pointList = (ArrayList) this.points.clone();
Cluster cluster = new Cluster(id, pointList);
return cluster;
}
}
算法工具类Chameleon.java:
package DataMining_Chameleon;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
/**
* Chameleon 两阶段聚类算法工具类
*
* @author lyq
*
*/
public class ChameleonTool {
// 测试数据点文件地址
private String filePath;
// 第一阶段的k近邻的k大小
private int k;
// 簇度量函数阈值
private double minMetric;
// 总的坐标点的个数
private int pointNum;
// 总的连接矩阵的情况,括号表示的是坐标点的id号
public static int[][] edges;
// 点与点之间的边的权重
public static double[][] weights;
// 原始坐标点数据
private ArrayList totalPoints;
// 第一阶段产生的所有的连通子图作为最初始的聚类
private ArrayList initClusters;
// 结果簇结合
private ArrayList resultClusters;
public ChameleonTool(String filePath, int k, double minMetric) {
this.filePath = filePath;
this.k = k;
this.minMetric = minMetric;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
Point p;
totalPoints = new ArrayList<>();
for (String[] array : dataArray) {
p = new Point(array[0], array[1], array[2]);
totalPoints.add(p);
}
pointNum = totalPoints.size();
}
/**
* 递归的合并小聚簇
*/
private void combineSubClusters() {
Cluster cluster = null;
resultClusters = new ArrayList<>();
// 当最后的聚簇只剩下一个的时候,则退出循环
while (initClusters.size() > 1) {
cluster = initClusters.get(0);
combineAndRemove(cluster, initClusters);
}
}
/**
* 递归的合并聚簇和移除聚簇
*
* @param clusterList
*/
private ArrayList combineAndRemove(Cluster cluster,
ArrayList clusterList) {
ArrayList remainClusters;
double metric = 0;
double maxMetric = -Integer.MAX_VALUE;
Cluster cluster1 = null;
Cluster cluster2 = null;
for (Cluster c2 : clusterList) {
if (cluster.id == c2.id) {
continue;
}
metric = calMetricfunction(cluster, c2, 1);
if (metric > maxMetric) {
maxMetric = metric;
cluster1 = cluster;
cluster2 = c2;
}
}
// 如果度量函数值超过阈值,则进行合并,继续搜寻可以合并的簇
if (maxMetric > minMetric) {
clusterList.remove(cluster2);
// 将边进行连接
connectClusterToCluster(cluster1, cluster2);
// 将簇1和簇2合并
cluster1.points.addAll(cluster2.points);
remainClusters = combineAndRemove(cluster1, clusterList);
} else {
clusterList.remove(cluster);
remainClusters = clusterList;
resultClusters.add(cluster);
}
return remainClusters;
}
/**
* 将2个簇进行边的连接
*
* @param c1
* 聚簇1
* @param c2
* 聚簇2
*/
private void connectClusterToCluster(Cluster c1, Cluster c2) {
ArrayList connectedEdges;
connectedEdges = c1.calNearestEdge(c2, 2);
for (int[] array : connectedEdges) {
edges[array[0]][array[1]] = 1;
edges[array[1]][array[0]] = 1;
}
}
/**
* 算法第一阶段形成局部的连通图
*/
private void connectedGraph() {
double distance = 0;
Point p1;
Point p2;
// 初始化权重矩阵和连接矩阵
weights = new double[pointNum][pointNum];
edges = new int[pointNum][pointNum];
for (int i = 0; i < pointNum; i++) {
for (int j = 0; j < pointNum; j++) {
p1 = totalPoints.get(i);
p2 = totalPoints.get(j);
distance = p1.ouDistance(p2);
if (distance == 0) {
// 如果点为自身的话,则权重设置为0
weights[i][j] = 0;
} else {
// 边的权重采用的值为距离的倒数,距离越近,权重越大
weights[i][j] = 1.0 / distance;
}
}
}
double[] tempWeight;
int[] ids;
int id1 = 0;
int id2 = 0;
// 对每个id坐标点,取其权重前k个最大的点进行相连
for (int i = 0; i < pointNum; i++) {
tempWeight = weights[i];
// 进行排序
ids = sortWeightArray(tempWeight);
// 取出前k个权重最大的边进行连接
for (int j = 0; j < ids.length; j++) {
if (j < k) {
id1 = i;
id2 = ids[j];
edges[id1][id2] = 1;
edges[id2][id1] = 1;
}
}
}
}
/**
* 权重的冒泡算法排序
*
* @param array
* 待排序数组
*/
private int[] sortWeightArray(double[] array) {
double[] copyArray = array.clone();
int[] ids = null;
int k = 0;
double maxWeight = -1;
ids = new int[pointNum];
for (int i = 0; i < pointNum; i++) {
maxWeight = -1;
for (int j = 0; j < copyArray.length; j++) {
if (copyArray[j] > maxWeight) {
maxWeight = copyArray[j];
k = j;
}
}
ids[i] = k;
// 将当前找到的最大的值重置为-1代表已经找到过了
copyArray[k] = -1;
}
return ids;
}
/**
* 根据边的连通性去深度优先搜索所有的小聚簇
*/
private void searchSmallCluster() {
int currentId = 0;
Point p;
Cluster cluster;
initClusters = new ArrayList<>();
ArrayList pointList = null;
// 以id的方式逐个去dfs搜索
for (int i = 0; i < pointNum; i++) {
p = totalPoints.get(i);
if (p.isVisited) {
continue;
}
pointList = new ArrayList<>();
pointList.add(p);
recusiveDfsSearch(p, -1, pointList);
cluster = new Cluster(currentId, pointList);
initClusters.add(cluster);
currentId++;
}
}
/**
* 深度优先的方式找到边所连接着的所有坐标点
*
* @param p
* 当前搜索的起点
* @param lastId
* 此点的父坐标点
* @param pList
* 坐标点列表
*/
private void recusiveDfsSearch(Point p, int parentId, ArrayList pList) {
int id1 = 0;
int id2 = 0;
Point newPoint;
if (p.isVisited) {
return;
}
p.isVisited = true;
for (int j = 0; j < pointNum; j++) {
id1 = p.id;
id2 = j;
if (edges[id1][id2] == 1 && id2 != parentId) {
newPoint = totalPoints.get(j);
pList.add(newPoint);
// 以此点为起点,继续递归搜索
recusiveDfsSearch(newPoint, id1, pList);
}
}
}
/**
* 计算连接2个簇的边的权重
*
* @param c1
* 聚簇1
* @param c2
* 聚簇2
* @return
*/
private double calEC(Cluster c1, Cluster c2) {
double resultEC = 0;
ArrayList connectedEdges = null;
connectedEdges = c1.calNearestEdge(c2, 2);
// 计算连接2部分的边的权重和
for (int[] array : connectedEdges) {
resultEC += weights[array[0]][array[1]];
}
return resultEC;
}
/**
* 计算2个簇的相对互连性
*
* @param c1
* @param c2
* @return
*/
private double calRI(Cluster c1, Cluster c2) {
double RI = 0;
double EC1 = 0;
double EC2 = 0;
double EC1To2 = 0;
EC1 = c1.calEC();
EC2 = c2.calEC();
EC1To2 = calEC(c1, c2);
RI = 2 * EC1To2 / (EC1 + EC2);
return RI;
}
/**
* 计算簇的相对近似度
*
* @param c1
* 簇1
* @param c2
* 簇2
* @return
*/
private double calRC(Cluster c1, Cluster c2) {
double RC = 0;
double EC1 = 0;
double EC2 = 0;
double EC1To2 = 0;
int pNum1 = c1.points.size();
int pNum2 = c2.points.size();
EC1 = c1.calEC();
EC2 = c2.calEC();
EC1To2 = calEC(c1, c2);
RC = EC1To2 * (pNum1 + pNum2) / (pNum2 * EC1 + pNum1 * EC2);
return RC;
}
/**
* 计算度量函数的值
*
* @param c1
* 簇1
* @param c2
* 簇2
* @param alpha
* 幂的参数值
* @return
*/
private double calMetricfunction(Cluster c1, Cluster c2, int alpha) {
// 度量函数值
double metricValue = 0;
double RI = 0;
double RC = 0;
RI = calRI(c1, c2);
RC = calRC(c1, c2);
// 如果alpha大于1,则更重视相对近似性,如果alpha逍遥于1,注重相对互连性
metricValue = RI * Math.pow(RC, alpha);
return metricValue;
}
/**
* 输出聚簇列
*
* @param clusterList
* 输出聚簇列
*/
private void printClusters(ArrayList clusterList) {
int i = 1;
for (Cluster cluster : clusterList) {
System.out.print("聚簇" + i + ":");
for (Point p : cluster.points) {
System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y));
}
System.out.println();
i++;
}
}
/**
* 创建聚簇
*/
public void buildCluster() {
// 第一阶段形成小聚簇
connectedGraph();
searchSmallCluster();
System.out.println("第一阶段形成的小簇集合:");
printClusters(initClusters);
// 第二阶段根据RI和RC的值合并小聚簇形成最终结果聚簇
combineSubClusters();
System.out.println("最终的聚簇集合:");
printClusters(resultClusters);
}
}
调用类Client.java:
package DataMining_Chameleon;
/**
* Chameleon(变色龙)两阶段聚类算法
* @author lyq
*
*/
public class Client {
public static void main(String[] args){
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\graphData.txt";
//k-近邻的k设置
int k = 1;
//度量函数阈值
double minMetric = 0.1;
ChameleonTool tool = new ChameleonTool(filePath, k, minMetric);
tool.buildCluster();
}
}
算法输出如下:
第一阶段形成的小簇集合:
聚簇1:(2, 2) (3, 1) (3, 4) (5, 3)
聚簇2:(3, 14) (10, 14) (11, 13)
聚簇3:(8, 3) (10, 4)
聚簇4:(8, 6) (9, 8) (10, 7) (12, 8) (10, 10)
聚簇5:(12, 15) (14, 15)
聚簇6:(14, 7) (15, 8) (14, 9)
最终的聚簇集合:
聚簇1:(2, 2) (3, 1) (3, 4) (5, 3) (8, 3) (10, 4)
聚簇2:(3, 14) (10, 14) (11, 13) (12, 15) (14, 15)
聚簇3:(8, 6) (9, 8) (10, 7) (12, 8) (10, 10) (14, 7) (15, 8) (14, 9)
图形展示情况如下:
首先是第一阶段形成小簇集的结果:
然后是第二阶段合并的结果:
与结果相对应,请读者细细比较。
算法总结
在算法的实现过程中遇到一个比较大的困惑点在于2个簇近和并的时候,合并边的选取,我是直接采用的是最近的2对顶点进行连接,显然这是不合理的,当簇与簇规模比较大的时候,这个连接边需要变多,我有想过做一个计算函数,帮我计算估计要连接几条边。这里再提几点变色龙算法的优缺点,首先是这个算法将互连性和近似性都考虑了进来,其次他能发现高质量的任意形状的簇,问题有,第一与KNN算法一样,这个k的取值永远是一个痛,时间复杂度高,有可能会达到O(n*n)的程度,细心的博友一定能观察到我好多地方用到了双次循环的操作了。
作者:Androidlushangderen 发表于2015/3/23 20:43:37 原文链接
阅读:829 评论:0 查看评论
dbscan基于密度的空间聚类算法
2015年3月16日 20:24
参考文献:百度百科 http://baike.baidu.com
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
算法介绍
说到聚类算法,大家如果有看过我写的一些关于机器学习的算法文章,一定都这类算法不会陌生,之前将的是划分算法(K均值算法)和层次聚类算法(BIRCH算法),各有优缺点和好坏。本文所述的算法是另外一类的聚类算法,他能够克服BIRCH算法对于形状的限制,因为BIRCH算法偏向于聚簇球形的聚类形成,而dbscan采用的是基于空间的密度的原理,所以可以适用于任何形状的数据聚类实现。
算法原理
在介绍算法原理之前,先介绍几个dbscan算法中的几个概念定义:
Ε领域:给定对象半径为Ε内的区域称为该对象的Ε领域;
核心对象:如果给定对象Ε领域内的样本点数大于等于MinPts,则称该对象为核心对象;
直接密度可达:对于样本集合D,如果样本点q在p的Ε领域内,并且p为核心对象,那么对象q从对象p直接密度可达。
密度可达:对于样本集合D,给定一串样本点p1,p2….pn,p= p1,q= pn,假如对象pi从pi-1直接密度可达,那么对象q从对象p密度可达。
密度相连:存在样本集合D中的一点o,如果对象o到对象p和对象q都是密度可达的,那么p和q密度相联。
下面是算法的过程(可能说的不是很清楚):
1、扫描原始数据,获取所有的数据点。
2、遍历数据点中的每个点,如果此点已经被访问(处理)过,则跳过,否则取出此点做聚类查找。
3、以步骤2中找到的点P为核心对象,找出在E领域内所有满足条件的点,如果个数大于等于MinPts,则此点为核心对象,加入到簇中。
4、再次P为核心对象的簇中的每个点,进行递归的扩增簇。如果P点的递归扩增结束,再次回到步骤2。
5、算法的终止条件为所有的点都被访问(处理过)。
算法可以理解为是一个DFS的深度优先扩展。
算法的实现
算法的输入Input(格式(x, y)):
2 2
3 1
3 4
3 14
5 3
8 3
8 6
9 8
10 4
10 7
10 10
10 14
11 13
12 8
12 15
14 7
14 9
14 15
15 8
坐标点类Point.java:
package DataMining_DBSCAN;
/**
* 坐标点类
*
* @author lyq
*
*/
public class Point {
// 坐标点横坐标
int x;
// 坐标点纵坐标
int y;
// 此节点是否已经被访问过
boolean isVisited;
public Point(String x, String y) {
this.x = (Integer.parseInt(x));
this.y = (Integer.parseInt(y));
this.isVisited = false;
}
/**
* 计算当前点与制定点之间的欧式距离
*
* @param p
* 待计算聚类的p点
* @return
*/
public double ouDistance(Point p) {
double distance = 0;
distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
* (this.y - p.y);
distance = Math.sqrt(distance);
return distance;
}
/**
* 判断2个坐标点是否为用个坐标点
*
* @param p
* 待比较坐标点
* @return
*/
public boolean isTheSame(Point p) {
boolean isSamed = false;
if (this.x == p.x && this.y == p.y) {
isSamed = true;
}
return isSamed;
}
}
算法工具类DNSCANTool.java:
package DataMining_DBSCAN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
/**
* DBSCAN基于密度聚类算法工具类
*
* @author lyq
*
*/
public class DBSCANTool {
// 测试数据文件地址
private String filePath;
// 簇扫描半径
private double eps;
// 最小包含点数阈值
private int minPts;
// 所有的数据坐标点
private ArrayList totalPoints;
// 聚簇结果
private ArrayList> resultClusters;
//噪声数据
private ArrayList noisePoint;
public DBSCANTool(String filePath, double eps, int minPts) {
this.filePath = filePath;
this.eps = eps;
this.minPts = minPts;
readDataFile();
}
/**
* 从文件中读取数据
*/
public void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
Point p;
totalPoints = new ArrayList<>();
for (String[] array : dataArray) {
p = new Point(array[0], array[1]);
totalPoints.add(p);
}
}
/**
* 递归的寻找聚簇
*
* @param pointList
* 当前的点列表
* @param parentCluster
* 父聚簇
*/
private void recursiveCluster(Point point, ArrayList parentCluster) {
double distance = 0;
ArrayList cluster;
// 如果已经访问过了,则跳过
if (point.isVisited) {
return;
}
point.isVisited = true;
cluster = new ArrayList<>();
for (Point p2 : totalPoints) {
// 过滤掉自身的坐标点
if (point.isTheSame(p2)) {
continue;
}
distance = point.ouDistance(p2);
if (distance <= eps) {
// 如果聚类小于给定的半径,则加入簇中
cluster.add(p2);
}
}
if (cluster.size() >= minPts) {
// 将自己也加入到聚簇中
cluster.add(point);
// 如果附近的节点个数超过最下值,则加入到父聚簇中,同时去除重复的点
addCluster(parentCluster, cluster);
for (Point p : cluster) {
recursiveCluster(p, parentCluster);
}
}
}
/**
* 往父聚簇中添加局部簇坐标点
*
* @param parentCluster
* 原始父聚簇坐标点
* @param cluster
* 待合并的聚簇
*/
private void addCluster(ArrayList parentCluster,
ArrayList cluster) {
boolean isCotained = false;
ArrayList addPoints = new ArrayList<>();
for (Point p : cluster) {
isCotained = false;
for (Point p2 : parentCluster) {
if (p.isTheSame(p2)) {
isCotained = true;
break;
}
}
if (!isCotained) {
addPoints.add(p);
}
}
parentCluster.addAll(addPoints);
}
/**
* dbScan算法基于密度的聚类
*/
public void dbScanCluster() {
ArrayList cluster = null;
resultClusters = new ArrayList<>();
noisePoint = new ArrayList<>();
for (Point p : totalPoints) {
if(p.isVisited){
continue;
}
cluster = new ArrayList<>();
recursiveCluster(p, cluster);
if (cluster.size() > 0) {
resultClusters.add(cluster);
}else{
noisePoint.add(p);
}
}
removeFalseNoise();
printClusters();
}
/**
* 移除被错误分类的噪声点数据
*/
private void removeFalseNoise(){
ArrayList totalCluster = new ArrayList<>();
ArrayList deletePoints = new ArrayList<>();
//将聚簇合并
for(ArrayList list: resultClusters){
totalCluster.addAll(list);
}
for(Point p: noisePoint){
for(Point p2: totalCluster){
if(p2.isTheSame(p)){
deletePoints.add(p);
}
}
}
noisePoint.removeAll(deletePoints);
}
/**
* 输出聚类结果
*/
private void printClusters() {
int i = 1;
for (ArrayList pList : resultClusters) {
System.out.print("聚簇" + (i++) + ":");
for (Point p : pList) {
System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y));
}
System.out.println();
}
System.out.println();
System.out.print("噪声数据:");
for (Point p : noisePoint) {
System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y));
}
System.out.println();
}
}
测试类Client.java:
package DataMining_DBSCAN;
/**
* Dbscan基于密度的聚类算法测试类
* @author lyq
*
*/
public class Client {
public static void main(String[] args){
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
//簇扫描半径
double eps = 3;
//最小包含点数阈值
int minPts = 3;
DBSCANTool tool = new DBSCANTool(filePath, eps, minPts);
tool.dbScanCluster();
}
}
算法的输出:
聚簇1:(2,2) (3,4) (5,3) (3,1) (8,3) (8,6) (10,4) (9,8) (10,7) (10,10) (12,8) (14,7) (14,9) (15,8)
聚簇2:(10,14) (11,13) (14,15) (12,15)
噪声数据:(3,14)
图示结果如下:
算法的缺点
dbscan虽说可以用于任何形状的聚类发现,但是对于密度分布不均衡的数据,变化比较大,分类的性能就不会特别好,还有1点是不能反映高尺寸数据。
作者:Androidlushangderen 发表于2015/3/16 20:24:44 原文链接
阅读:762 评论:2 查看评论
Genetic Algorithm遗传算法学习
2015年3月3日 18:34
参考资料:http://blog.csdn.net/b2b160/article/details/4680853/#comments(冒昧的用了链接下的几张图)
百度百科:http://baike.baidu.com/link?url=FcwTBx_yPcD5DDEnN1FqvTkG4QNllkB7Yis6qFOL65wpn6EdT5LXFxUCmv4JlUfV3LUPHQGdYbGj8kHVs3GuaK
算法介绍
遗传算法是模拟达尔文生物进化论的自然选择和遗传学进化机理的计算模型。运用到了生物学中“适者生存,优胜劣汰”的原理。在每一次的进化过程中,把拥有更好环境适应性的基因传给下一代,直到最后的个体满足特定的条件,代表进化的结束,GA(后面都以GA代称为遗传算法的意思)算法是一种利用生物进化理论来搜索最优解的一种算法。
算法原理
算法的基本框架
了解算法的基本框架是理解整个算法的基础,算法的框架包含编码、适应度函数、初始群体的选择。先假设本例子的目标函数如下,求出他的最大值
f(x) = x1 * x1 + x2 * x2; 1<= x1 <=7, 1<= x2<=7
1、适应度函数。适应度函数是用来计算个体的适应值计算,顾名思义,适应值越高的个体,环境适应性越好,自然就要有更多的机会把自己的基因传给下一代,所以,其实适应度函数在这里起着一个过滤条件的作用。在本例子,目标函数总为非负值,并且以函数最大化为优化目标,所以可以将函数的值作为适应值。
2、编码。编码指的是将个体的信息表示成编码的字符串形式。如果个体变量时数字的形式,可以转为二进制的方式。
算法的群体选择过程
这个过程是遗传算法的核心过程,在里面分为了3个小的步骤,选择,交叉,变异。
1、初始个体的选择过程。就是挑选哪些个体作为即将产生下一代的个体呢。过程如下:
(1).利用适值函数,计算每个个体的适值,计算每个个体的适值占总和的百分比。
(2).根据百分比为每个个体划定一定的所属区间。
(3).产生一个[0, 1]的小数,判断这个小数点落在哪个个体的区间内,就表明要选出这个个体。这里其实就已经蕴含着把高适值的个体优先传入下一代,因为适值高,有更高的几率小数是落在自己的区间内的。
用图示范的形式表现如下:
2、交叉运算。个体的交叉运算过程的步骤细节如下:
(1).首先对于上个选择步骤选择来的个体进行随机的两两配对。
(2).取出其中配对的一对个体,随机设定一个交叉点,2个个体的编码的交叉点后的编码值进行对调,生成新的2个个体编码。
(3).所有的配对的个体都执行步骤(2)操作,最后加入到一个结果集合中。
交叉运算的方式又很多,上面用的方法是其中比较常用的单点交叉方式。
用图示范的形式表现如下:
3.变异运算。变异运算过程的步骤细节如下:
(1).遍历从交叉运算所得结果的结果集,取出集中一个个体编码,准备做变异操作
(2).产生随机的一个变异点位置。所选个体的变异点位置的值做变异操作,将他的值取为反向的值。
(3).将所有的交叉运算所得的结果集中的元素都执行步骤(2)操作。
用图示范的形式如下:
整个遗传算法的原理过程,用一个流程图的表现形式如下:
算法代码实现
算法代码的测试用例正如算法原理所举的一样,遗传进化的阈值条件为:个体中产生了使目标函数最大化值的个体,就是基因为111111。
GATool.java:
package GA;
import java.util.ArrayList;
import java.util.Random;
/**
* 遗传算法工具类
*
* @author lyq
*
*/
public class GATool {
// 变量最小值
private int minNum;
// 变量最大值
private int maxNum;
// 单个变量的编码位数
private int codeNum;
// 初始种群的数量
private int initSetsNum;
// 随机数生成器
private Random random;
// 初始群体
private ArrayList initSets;
public GATool(int minNum, int maxNum, int initSetsNum) {
this.minNum = minNum;
this.maxNum = maxNum;
this.initSetsNum = initSetsNum;
this.random = new Random();
produceInitSets();
}
/**
* 产生初始化群体
*/
private void produceInitSets() {
this.codeNum = 0;
int num = maxNum;
int[] array;
initSets = new ArrayList<>();
// 确定编码位数
while (num != 0) {
codeNum++;
num /= 2;
}
for (int i = 0; i < initSetsNum; i++) {
array = produceInitCode();
initSets.add(array);
}
}
/**
* 产生初始个体的编码
*
* @return
*/
private int[] produceInitCode() {
int num = 0;
int num2 = 0;
int[] tempArray;
int[] array1;
int[] array2;
tempArray = new int[2 * codeNum];
array1 = new int[codeNum];
array2 = new int[codeNum];
num = 0;
while (num < minNum || num > maxNum) {
num = random.nextInt(maxNum) + 1;
}
numToBinaryArray(array1, num);
while (num2 < minNum || num2 > maxNum) {
num2 = random.nextInt(maxNum) + 1;
}
numToBinaryArray(array2, num2);
// 组成总的编码
for (int i = 0, k = 0; i < tempArray.length; i++, k++) {
if (k < codeNum) {
tempArray[i] = array1[k];
} else {
tempArray[i] = array2[k - codeNum];
}
}
return tempArray;
}
/**
* 选择操作,把适值较高的个体优先遗传到下一代
*
* @param initCodes
* 初始个体编码
* @return
*/
private ArrayList selectOperate(ArrayList initCodes) {
double randomNum = 0;
double sumAdaptiveValue = 0;
ArrayList resultCodes = new ArrayList<>();
double[] adaptiveValue = new double[initSetsNum];
for (int i = 0; i < initSetsNum; i++) {
adaptiveValue[i] = calCodeAdaptiveValue(initCodes.get(i));
sumAdaptiveValue += adaptiveValue[i];
}
// 转成概率的形式,做归一化操作
for (int i = 0; i < initSetsNum; i++) {
adaptiveValue[i] = adaptiveValue[i] / sumAdaptiveValue;
}
for (int i = 0; i < initSetsNum; i++) {
randomNum = random.nextInt(100) + 1;
randomNum = randomNum / 100;
sumAdaptiveValue = 0;
// 确定区间
for (int j = 0; j < initSetsNum; j++) {
if (randomNum > sumAdaptiveValue
&& randomNum <= sumAdaptiveValue + adaptiveValue[j]) {
//采用拷贝的方式避免引用重复
resultCodes.add(initCodes.get(j).clone());
break;
} else {
sumAdaptiveValue += adaptiveValue[j];
}
}
}
return resultCodes;
}
/**
* 交叉运算
*
* @param selectedCodes
* 上步骤的选择后的编码
* @return
*/
private ArrayList crossOperate(ArrayList selectedCodes) {
int randomNum = 0;
// 交叉点
int crossPoint = 0;
ArrayList resultCodes = new ArrayList<>();
// 随机编码队列,进行随机交叉配对
ArrayList randomCodeSeqs = new ArrayList<>();
// 进行随机排序
while (selectedCodes.size() > 0) {
randomNum = random.nextInt(selectedCodes.size());
randomCodeSeqs.add(selectedCodes.get(randomNum));
selectedCodes.remove(randomNum);
}
int temp = 0;
int[] array1;
int[] array2;
// 进行两两交叉运算
for (int i = 1; i < randomCodeSeqs.size(); i++) {
if (i % 2 == 1) {
array1 = randomCodeSeqs.get(i - 1);
array2 = randomCodeSeqs.get(i);
crossPoint = random.nextInt(2 * codeNum - 1) + 1;
// 进行交叉点位置后的编码调换
for (int j = 0; j < 2 * codeNum; j++) {
if (j >= crossPoint) {
temp = array1[j];
array1[j] = array2[j];
array2[j] = temp;
}
}
// 加入到交叉运算结果中
resultCodes.add(array1);
resultCodes.add(array2);
}
}
return resultCodes;
}
/**
* 变异操作
*
* @param crossCodes
* 交叉运算后的结果
* @return
*/
private ArrayList variationOperate(ArrayList crossCodes) {
// 变异点
int variationPoint = 0;
ArrayList resultCodes = new ArrayList<>();