转载请注明:http://blog.csdn.net/HEL_WOR/article/details/50583923
一直想找个方法把分类的数据直观的展示出来,最近在Java上发现了类似Pyhton的MatPlotLib库的jar包,上周末在屋里把代码折腾出来了。
Kmeans算属于非监督的聚类算法。
监督学习的定义是通过对算法进行有正面影响和负面影响的训练,算法能够学习出一种模型,这个时候我们将测试数据输入这个模型后,模型能得出我们想要的结果,例如分类。举个例子,我们怎样训练自己家的狗狗要听话,我们通过狗狗做对的事情奖励它,做错了事情教育它,那么足够多的时间后我们就会觉得狗狗越来越听话了。这就属于监督学习。
非监督学习是我们无法确切的去训练算法认识到什么是positive的,什么是negative的,具体的学习过程只能由算法自己去描述。
Kmeans属于非监督学习,与SVM,Logistic,决策树等监督学习算法不同,我们不能如果训练SVM那样给出一系列的正反数据点去训练算法成为一种能分辨正负类点的模型。Kmeans需要自己去学习如果分类。所以分类的本身就只能来源于算法的自身描述。
这个算法是一个坐标上升的优化算法,有兴趣的话可以自己推导,公式会不停交替地对u和c做求导操作,当导数趋于0时结果趋于不变,最终算法收敛。
根据这个距离公式,我们可以根据距离来描述类别。设定k个中心点,根据所有数据点与k个中心点的距离,把与中心点距离最近的点划分为一类。这是对公式中u的优化,同时对每一个类型重新计算中心点,这是对公式中c的优化,到最后,算法收敛,分类完成。
算法在实现后,要注意:
1.类别的个数会对分类结果有很大影响。
2.初始中心点的选择不当会造成数据点分类错误。
关于初始中心点,在SVM看到过有自启发式的方式。不知道能不能用到Kmeans上。
测试数据集: testData
测试结果(JFreeChart实现图像展示):
Kmeans类(算法主逻辑):
package helWor.Share.KMeans;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class KMeans {
private static int ClusterNum;
private static ArrayList<ArrayList<double[]>> cluster;
private static double[][] center = new double[ClusterNum][2];
private static double[][] lastCenter = new double[ClusterNum][2];
private ArrayList<double[]> dataSet = new ArrayList<double[]>();
/* * 构造函数 */
public KMeans(int clusterNum)
{
ClusterNum = clusterNum;
}
/* * 主执行方法 */
public void ExecuteMethod() throws IOException
{
LoadDataSet();
initCenters();
do
{
initCluster();
AllocateCluster();
lastCenter = center;
setNewCenter();
}
while(this.IsCenterChanged(center));
}
/* * 获取簇 */
public ArrayList<ArrayList<double[]>> getCluster()
{
return cluster;
}
/* * 装载数据 */
private void LoadDataSet() throws IOException
{
String fileName = "E:" + File.separator + "testSet.txt";
FileReader fileReader = new FileReader(fileName);
BufferedReader bufferReader = new BufferedReader(fileReader);
String line = bufferReader.readLine();
while(!(line.isEmpty() || line == null))
{
double[] data = new double[2];
String[] input = line.split("\t");
data[0] = Double.parseDouble(input[0]);
data[1] = Double.parseDouble(input[1]);
line = bufferReader.readLine();
dataSet.add(data);
}
fileReader.close();
bufferReader.close();
}
/* * 判断簇中心点是否改变 作为算法结束条件 */
private boolean IsCenterChanged(double[][] center)
{
for(int i = 0; i < center.length; i++)
{
if(center[i][0] != lastCenter[i][0] || center[i][1] != lastCenter[i][1])
{
return true;
}
}
return false;
}
/* * 初始化簇中心 */
private void initCenters()
{
//// 中心点的设置会导致分类失败
center = new double[][]{{-1,-2},{-3,2},{2,4}};
}
/* * 初始化簇容器 */
private void initCluster()
{
ArrayList<ArrayList<double[]>> initCluster = new ArrayList<ArrayList<double[]>>();
for(int i = 0; i < ClusterNum; i++)
{
initCluster.add(new ArrayList<double[]>());
}
if(cluster != null)
{
cluster.clear();
}
cluster = initCluster;
}
/* * 计算欧式距离 用以根据距离完成分类 */
private double CalcDistance(double[] element, double[] center)
{
double x = element[0] - center[0];
double y = element[1] - center[1];
double z = x*x + y*y;
return (double)Math.sqrt(z);
}
/* * 获取这个节点属于哪个簇 */
private int getClusterIndex(double[] distance)
{
double minDistance = distance[0];
int clusterIndex = 0;
for(int i = 0; i < distance.length; i++)
{
if(distance[i] < minDistance)
{
minDistance = distance[i];
clusterIndex = i;
}
}
return clusterIndex;
}
/* * 分配簇 */
private void AllocateCluster()
{
double[] distance = new double[ClusterNum];
for(double[] data : dataSet)
{
for(int j = 0; j < ClusterNum; j++)
{
distance[j] = this.CalcDistance(data, center[j]);
}
int clusterIndex = this.getClusterIndex(distance);
/* * 如果用ArrayList<double[][]>来描述簇也是可行的 但是在这里会很不好处理 不可能为每一个簇都保存一个索引值 */
cluster.get(clusterIndex).add(data);
}
}
/* * 设置新的簇中心 */
private void setNewCenter()
{
center = new double[ClusterNum][2];
for(int i = 0; i < center.length; i++)
{
if(cluster.get(i).size() != 0)
{
double[] newCenter = new double[2];
for(int j = 0; j < cluster.get(i).size(); j++)
{
newCenter[0] += cluster.get(i).get(j)[0];
newCenter[1] += cluster.get(i).get(j)[1];
}
center[i][0] = newCenter[0]/cluster.get(i).size();
center[i][1] = newCenter[1]/cluster.get(i).size();
}
}
}
}
JfreeChartScatter:(图像展示)
package helWor.Share.KMeans;
import java.util.ArrayList;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.XYPlot;
import org.jfree.ui.RefineryUtilities;
import org.jfree.chart.ChartFactory;
import org.jfree.data.xy.DefaultXYDataset;
import java.awt.*;
import org.jfree.chart.renderer.xy.XYItemRenderer;
public class JfreeChartScatter {
/** The data. */
private double[][] data = new double[2][100];
/* * 默认构造函数 */
public JfreeChartScatter()
{
}
/* * 调用JfreeChart绘制散点图 */
public void Scatter(ArrayList<ArrayList<double[]>> cluster)
{
DefaultXYDataset xyDataset = new DefaultXYDataset();
for(int i = 0; i < cluster.size(); i++)
{
for(int j = 0; j < cluster.get(i).size(); j++)
{
this.data[0][j] = cluster.get(i).get(j)[0];
this.data[1][j] = cluster.get(i).get(j)[1];
}
/** 设置分类个数 **/
xyDataset.addSeries("Cluster" + (i + 1), this.data);
//**让指针指向另一个内纯区域 避免引用类型的数组造成加入的数据都是最后一个数组的数据**//
this.data = new double[2][100];;
}
JFreeChart jfree = ChartFactory.createScatterPlot("KMeans", "X", "Y", xyDataset);
XYPlot xyPlot = (XYPlot)jfree.getPlot();
/**设置数据点大小,颜色**/
XYItemRenderer renderer = xyPlot.getRenderer();
renderer.setSeriesOutlineStroke(0, new BasicStroke(0.05f));
renderer.setSeriesPaint(0, Color.BLUE);
renderer.setSeriesOutlineStroke(1, new BasicStroke(0.05f));
renderer.setSeriesPaint(1, Color.RED);
renderer.setSeriesOutlineStroke(2, new BasicStroke(0.05f));
renderer.setSeriesPaint(2, Color.ORANGE);
renderer.setSeriesVisible(0, true);
renderer.setSeriesVisible(1, true);
renderer.setSeriesVisible(2, true);
jfree.setBackgroundPaint(Color.WHITE);
ChartFrame frame = new ChartFrame("Kmeans",jfree);
frame.pack();
//**设置图像展示在屏幕正中**//
RefineryUtilities.centerFrameOnScreen(frame);
//**设置图像可见**//
frame.setVisible(true);
}
}
调用方法:
public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
ApplicationContext context = new ClassPathXmlApplicationContext("application.xml");
KMeans kmeans = (KMeans) context.getBean("KMeans");
JfreeChartScatter scatter = (JfreeChartScatter)context.getBean("Scatter");
kmeans.ExecuteMethod();
scatter.Scatter(kmeans.getCluster());
}
xml(JDK8+Spring4):
<?xml version="1.0" encoding="UTF-8"?>
<beans xmlns="http://www.springframework.org/schema/beans" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd">
<bean id="Scatter" class="helWor.Share.KMeans.JfreeChartScatter"></bean>
<bean id="KMeans" class="helWor.Share.KMeans.KMeans">
<constructor-arg value="3" type="int"/>
</bean>
</beans>
谢谢阅读。