K-means, also known as the k-means algorithm, is a clustering algorithm and belongs to the family of unsupervised learning methods. Given a sample set, k-means groups similar samples together, eventually partitioning the set into K clusters such that the members of each cluster are highly similar to one another.
This program uses K-means to cluster the KDD CUP99 network intrusion detection dataset. It first applies feature conversion and normalization to the input records, then clusters the data into two groups on Flink, separating normal points from abnormal ones so that intrusion anomalies can be detected. Each KDD99 record consists of 41 features (three of them string-valued: protocol type, network service, and connection flag) plus a trailing label, 42 fields in total.
Environment: flink-1.9.1
The algorithm proceeds as follows:
1. Randomly select K cluster centres (two in this example, to separate normal points from abnormal ones).
2. Compute the distance from every sample to each cluster centre and assign the sample to the nearest cluster.
3. Compute the mean of the samples in each cluster and take that mean as the new cluster centre.
4. Repeat steps 2 and 3 until the distance each cluster centre moves falls below a given threshold.
5. Output the final cluster centres and their member samples (see the single-machine sketch right after this list).
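Before walking through the Flink job, here is a minimal, self-contained sketch of the five steps on plain double[] points. It is an illustration only, with names of our own choosing (KMeansSketch, kmeans, distance are not part of the job); the Flink program below parallelizes the same logic.

import java.util.Random;

public class KMeansSketch {
    public static double[][] kmeans(double[][] points, int k, int maxIterations, double disDiff) {
        Random random = new Random();
        double[][] centres = new double[k][];
        // Step 1: pick k random points as the initial centres (duplicates possible here;
        // the Flink job below deduplicates its random picks).
        for (int i = 0; i < k; i++) {
            centres[i] = points[random.nextInt(points.length)].clone();
        }
        for (int iter = 0; iter < maxIterations; iter++) {
            double[][] sums = new double[k][points[0].length];
            long[] counts = new long[k];
            // Step 2: assign every point to its nearest centre.
            for (double[] p : points) {
                int nearest = 0;
                double minDistance = Double.MAX_VALUE;
                for (int c = 0; c < k; c++) {
                    double d = distance(p, centres[c]);
                    if (d < minDistance) { minDistance = d; nearest = c; }
                }
                for (int j = 0; j < p.length; j++) { sums[nearest][j] += p[j]; }
                counts[nearest]++;
            }
            // Step 3: the mean of each cluster becomes the new centre.
            double maxMove = 0;
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) { continue; } // an empty cluster keeps its old centre
                double[] newCentre = new double[sums[c].length];
                for (int j = 0; j < newCentre.length; j++) { newCentre[j] = sums[c][j] / counts[c]; }
                maxMove = Math.max(maxMove, distance(centres[c], newCentre));
                centres[c] = newCentre;
            }
            // Step 4: stop once no centre moved farther than the threshold.
            if (maxMove < disDiff) { break; }
        }
        // Step 5: return the final centres; members are recovered by one more assignment pass.
        return centres;
    }

    static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) { sum += (a[i] - b[i]) * (a[i] - b[i]); }
        return Math.sqrt(sum);
    }
}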
The job accepts the following command-line parameters:
pointFile: path to the intrusion detection data file
outputPath: directory for the output results
maxIterations: maximum number of iterations of the algorithm
disDiff: termination criterion, i.e. the distance a cluster centre moves between two consecutive iterations
kNum: the value of K, i.e. the number of clusters
Note: any parameter that is not specified falls back to the default in the KMeansConstant class.
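For example, once the job is packaged into a jar, it could be submitted like this (the jar name and the data paths are placeholders; the parameter names are the ones parsed by ParameterTool in the main class):

flink run -c cn.xsy.algorithm.kmeans.KMeans kmeans.jar \
  --pointFile /data/kdd99/corrected \
  --outputPath /data/kdd99/output \
  --maxIterations 10 \
  --disDiff 1.0E-13 \
  --kNum 2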
Multi-dimensional data point
package cn.xsy.algorithm.kmeans;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Multi-dimensional data point.
*/
public class Point implements Serializable {
//raw feature values
public List<String> sourceFields;
//converted feature values
public List<Double> handledFields;
//per-feature maxima, used for normalization
public List<Double> fieldsMaxValue = new ArrayList<Double>(42);
//per-feature minima, used for normalization
public List<Double> fieldsMinValue = new ArrayList<Double>(42);
//number of points accumulated into this one, used as the divisor when computing a new cluster centre
public Long number = 1L;
public Point(){}
public Point(List<String> list){
sourceFields = list;
}
//convert the string-valued features into numeric features
public Point featureHandled(){
handledFields = new ArrayList<Double>();
for (int i = 0; i < sourceFields.size(); i++) {
if(i == 1){
//protocol type
handledFields.add((double)(Arrays.asList(KMeansConstant.PROTOCOLS).indexOf(sourceFields.get(i))));
}else if(i == 2){
//network service type; every service missing from SERVICES maps to the same fallback code SERVICES.length
int index = Arrays.asList(KMeansConstant.SERVICES).indexOf(sourceFields.get(i));
handledFields.add((double)(index == -1 ? KMeansConstant.SERVICES.length : index));
}else if(i == 3){
//connection status flag
handledFields.add((double)(Arrays.asList(KMeansConstant.FLAGS).indexOf(sourceFields.get(i))));
}else if(i == 41){
//label; every label missing from LABELS maps to the same fallback code LABELS.length
int index = Arrays.asList(KMeansConstant.LABELS).indexOf(sourceFields.get(i));
handledFields.add((double)(index == -1 ? KMeansConstant.LABELS.length : index));
}else {
handledFields.add(Double.parseDouble(sourceFields.get(i)));
}
}
return this;
}
//track the per-feature maximum and minimum seen so far
public Point maxMinValue(Point point){
if(fieldsMaxValue.size() == 0){
fieldsMaxValue.addAll(handledFields);
}
if(fieldsMinValue.size() == 0){
fieldsMinValue.addAll(handledFields);
}
if(point.fieldsMaxValue.size() == 0){
point.fieldsMaxValue.addAll(point.handledFields);
}
if(point.fieldsMinValue.size() == 0){
point.fieldsMinValue.addAll(point.handledFields);
}
//take the element-wise max and min of the two points' feature values
for(int i = 0; i< handledFields.size(); i++){
if(point.fieldsMaxValue.get(i) > this.fieldsMaxValue.get(i)){
fieldsMaxValue.set(i,point.fieldsMaxValue.get(i));
}
if(point.fieldsMinValue.get(i) < this.fieldsMinValue.get(i)){
fieldsMinValue.set(i,point.fieldsMinValue.get(i));
}
}
return this;
}
//min-max normalization: X(norm) = (X - min) / (max - min)
public Point standardHandled(Point point){
for(int i = 0; i< handledFields.size(); i++){
double max = point.fieldsMaxValue.get(i);
double min = point.fieldsMinValue.get(i);
double value = handledFields.get(i);
//a constant feature (max == min) keeps its value, avoiding division by zero
handledFields.set(i, max == min ? min : (value - min) / (max - min));
}
return this;
}
//accumulator: element-wise sum of two points
public Point add(Point other){
//sum the feature values
for (int i = 0; i < handledFields.size(); i++) {
handledFields.set(i,handledFields.get(i) + other.handledFields.get(i));
}
//accumulate the point counts
number += other.number;
return this;
}
//divide every feature value by val, turning an accumulated sum into a mean
public Point div(long val){
for (int i = 0; i < handledFields.size(); i++) {
handledFields.set(i,handledFields.get(i) / val);
}
return this;
}
//Euclidean distance between two points
public double euclideanDistance(Point other){
double sum = 0;
for (int i = 0; i < handledFields.size(); i++) {
sum += Math.pow((handledFields.get(i) - other.handledFields.get(i)),2);
}
return Math.sqrt(sum);
}
@Override
public String toString() {
return "Point{" +
"sourceFields=" + sourceFields +
'}';
}
}
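As a quick sanity check, the sample record quoted in KMeansConstant below can be pushed through this pipeline. The snippet is illustrative only (the PointDemo class is ours and would sit in cn.xsy.algorithm.kmeans so that Point and KMeansConstant resolve); it simply prints the 42 numeric features:

import java.util.Arrays;

public class PointDemo {
    public static void main(String[] args) {
        String record = "0,tcp,http,SF,228,896,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,23,24,"
                + "0.00,0.00,0.00,0.00,1.00,0.00,0.08,255,255,1.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,normal.";
        Point point = new Point(Arrays.asList(record.split(","))).featureHandled();
        // The four string features become numeric codes:
        // "tcp" -> 0.0 (index in PROTOCOLS), "http" -> 21.0 (SERVICES),
        // "SF" -> 9.0 (FLAGS), "normal." -> 0.0 (LABELS).
        System.out.println(point.handledFields);
    }
}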
Cluster centre
package cn.xsy.algorithm.kmeans;
import java.io.Serializable;
/**
* Cluster centre.
*/
public class Cluster implements Serializable {
//cluster id
public int id;
//centre point of the cluster
public Point centre;
public Cluster(int id, Point centre) {
this.id = id;
this.centre = centre;
}
public Cluster() {
}
@Override
public String toString() {
return "Cluster{" +
"id=" + id +
", centre=" + centre +
'}';
}
}
KMeans constants
package cn.xsy.algorithm.kmeans;
public final class KMeansConstant {
//intrusion detection data file; a sample record:
//0,tcp,http,SF,228,896,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,23,24,0.00,0.00,0.00,0.00,1.00,0.00,0.08,255,255,1.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,normal.
public static final String POINTFILE = "C:\\Users\\xsy\\Desktop\\KDD99入侵检测数据集\\Test data with corrected labels\\corrected\\corrected";
//output directory for the results
public static final String OUTPUTPATH = "C:\\Users\\xsy\\Desktop\\KDD99入侵检测数据集\\output";
//maximum number of iterations
public static final int MAXITERATIONS = 10;
//termination criterion: the distance a cluster centre moves between two consecutive iterations
public static final double DISDIFF = 1.0E-13;
//K, the number of clusters
public static final int KNUM = 2;
/** lookup tables for the string feature conversion **/
//protocol types
public static final String[] PROTOCOLS = new String[]{"tcp","udp","icmp"};
//network service types on the destination host
public static final String[] SERVICES = new String[]{"aol","auth","bgp","courier","csnet_ns","ctf","daytime","discard","domain","domain_u",
"echo","eco_i","ecr_i","efs","exec","finger","ftp","ftp_data","gopher","harvest","hostnames",
"http","http_2784","http_443","http_8001","imap4","IRC","iso_tsap","klogin","kshell","ldap",
"link","login","mtp","name","netbios_dgm","netbios_ns","netbios_ssn","netstat","nnsp","nntp",
"ntp_u","other","pm_dump","pop_2","pop_3","printer","private","red_i","remote_job","rje","shell",
"smtp","sql_net","ssh","sunrpc","supdup","systat","telnet","tftp_u","tim_i","time","urh_i","urp_i",
"uucp","uucp_path","vmnet","whois","X11","Z39_50"};
//connection status flags (normal or error)
public static final String[] FLAGS = new String[]{"OTH","REJ","RSTO","RSTOS0","RSTR","S0","S1","S2","S3","SF","SH"};
//label types
public static final String[] LABELS = new String[]{"normal.", "buffer_overflow.", "loadmodule.", "perl.", "neptune.", "smurf.",
"guess_passwd.", "pod.", "teardrop.", "portsweep.", "ipsweep.", "land.", "ftp_write.",
"back.", "imap.", "satan.", "phf.", "nmap.", "multihop.", "warezmaster.", "warezclient.",
"spy.", "rootkit.",
"mscan.", "saint.", "apache2.", "mailbomb.", "processtable.", "udpstorm.", "httptunnel.", "ps.",
"sqlattack.", "xterm.", "named.", "sendmail.", "snmpgetattack.", "snmpguess.", "worm.", "xlock.", "xsnoop."};
}
KMeans main entry point
package cn.xsy.algorithm.kmeans;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.FileSystem;
import java.util.*;
/**
* KMeans main entry point.
*/
public class KMeans {
public static void main(String[] args) throws Exception {
//parse the command-line parameters
ParameterTool params = ParameterTool.fromArgs(args);
//set up the execution environment
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
//make the parameters available in the web UI
env.getConfig().setGlobalJobParameters(params);
//read the data points from the given file path
DataSource<String> sourcePoints = getPointDataSet(params, env);
//randomly pick the initial cluster centres from the data points
List<String> pointCollect = sourcePoints.collect();
Set<String> sourceClusterSet = getSourceClusterCollection(pointCollect, params);
DataSource<String> sourceCluster = env.fromCollection(sourceClusterSet);
//convert the string features of the data points to numeric features
DataSet<Point> featurePoints = sourcePoints.map(new FeatureHandledPoint());
//compute the per-feature maximum and minimum over all data points
DataSet<Point> maxMinPoint = featurePoints.reduce(new MaxMinHandledPoint());
//normalize the data points
DataSet<Point> points = featurePoints.map(new StandardHandledPoint()).withBroadcastSet(maxMinPoint, "maxMinPoint");
//apply the same feature conversion and normalization to the cluster centres
DataSet<Cluster> clusters = sourceCluster.map(new HandledCluster()).withBroadcastSet(maxMinPoint, "maxMinPoint");
//set the maximum number of KMeans iterations
IterativeDataSet<Cluster> loop = clusters.iterate(params.getInt("maxIterations", KMeansConstant.MAXITERATIONS));
//the KMeans iteration step
DataSet<Cluster> newClusters = points
//assign each data point to the nearest cluster centre
.map(new SelectNearestCluster()).withBroadcastSet(loop, "clusters")
//sum the point coordinates and the point counts within each cluster
.groupBy(0).reduce(new ClusterAccumulator())
//compute the new cluster centres
.map(new ClusterAverager());
//termination criterion
DataSet<Tuple2<Cluster, Cluster>> termination = loop
//join each centre from before the iteration with its counterpart after it
.join(newClusters).where("id").equalTo("id")
//keep only the centres that moved farther than the threshold
.filter(new TerminationCriterion());
//feed the new centres into the next iteration; the loop also stops early once the termination data set becomes empty
DataSet<Cluster> finalClusters = loop.closeWith(newClusters, termination);
//assign every point to its final cluster
DataSet<Tuple2<Integer, Point>> clusterPoints = points.map(new SelectNearestCluster()).withBroadcastSet(finalClusters, "clusters");
//count the points per cluster and per label (optional, kept commented out)
// DataSet<Tuple3<Integer, String, Long>> clusterLabelsCount = clusterPoints.map(new CountClusterLabels()).groupBy(0,1).aggregate(Aggregations.SUM, 2);
//count the points in each cluster
DataSet<Tuple2<Integer, Long>> clusterCount = clusterPoints.map(new CountCluster()).groupBy(0).aggregate(Aggregations.SUM, 1);
// write out the results
String outputPath = params.has("outputPath") ? params.get("outputPath") : KMeansConstant.OUTPUTPATH;
clusterPoints.writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE);
clusterCount.print();
// env.execute("KDD CUP99 KMeans"); //not needed here: print() above already triggers execution
//print some statistics and compute the purity of the clustering
JobExecutionResult lastJobExecutionResult = env.getLastJobExecutionResult();
double purity = getPurity(lastJobExecutionResult);
System.out.println("purity: " + purity);
}
/**
* Computes the purity of the KMeans clustering.
*
* @param lastJobExecutionResult
* @return
*/
private static double getPurity(JobExecutionResult lastJobExecutionResult) {
//total number of data points
int pointCount = lastJobExecutionResult.getAccumulatorResult("pointCount");
//normal points in cluster 1
int cluster1Normal = lastJobExecutionResult.getAccumulatorResult("cluster1Normal");
//abnormal points in cluster 1
int cluster1Abnormal = lastJobExecutionResult.getAccumulatorResult("cluster1Abnormal");
//normal points in cluster 2
int cluster2Normal = lastJobExecutionResult.getAccumulatorResult("cluster2Normal");
//abnormal points in cluster 2
int cluster2Abnormal = lastJobExecutionResult.getAccumulatorResult("cluster2Abnormal");
double purity;
//treat the cluster holding more abnormal points as the attack cluster, the other as the
//normal cluster, and report the fraction of points that fall into the matching cluster
if(cluster1Abnormal > cluster2Abnormal){
purity = (double) (cluster1Abnormal + cluster2Normal) / pointCount;
} else if(cluster1Abnormal < cluster2Abnormal){
purity = (double) (cluster2Abnormal + cluster1Normal) / pointCount;
}else {
if(cluster1Normal > cluster2Normal){
purity = (double) (cluster2Abnormal + cluster1Normal) / pointCount;
}else {
purity = (double) (cluster1Abnormal + cluster2Normal) / pointCount;
}
}
System.out.println("数据点总个数: " + pointCount);
System.out.println("簇中心1正常点个数: " + cluster1Normal);
System.out.println("簇中心1异常点个数: " + cluster1Abnormal);
System.out.println("簇中心2正常点个数: " + cluster2Normal);
System.out.println("簇中心2异常点个数: " + cluster2Abnormal);
return purity;
}
/**
* Builds the input point data set.
*
* @param params
* @param env
* @return
*/
private static DataSource<String> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
String pointFile = params.has("pointFile") ? params.get("pointFile") : KMeansConstant.POINTFILE;
DataSource<String> sourcePoints = env.readTextFile(pointFile);
return sourcePoints;
}
/**
* Randomly selects the initial cluster centres from the data points and builds the centre collection.
*
* @param sourcePointList
* @param params
* @return
*/
private static Set<String> getSourceClusterCollection(List<String> sourcePointList, ParameterTool params) {
int kNum = params.has("kNum") ? Integer.parseInt(params.get("kNum")) : KMeansConstant.KNUM;
Set<String> clusterSet = new HashSet<String>();
Random random = new Random();
for (int id = 1; id <= kNum; ) {
String point = sourcePointList.get(random.nextInt(sourcePointList.size()));
//marks whether this data point has already been chosen
boolean flag = true;
for (String cluster : clusterSet) {
String[] split = cluster.split(" ");
if (split[0].equals(point)) {
flag = false;
}
}
//if the randomly chosen point has not been selected before, add it to the set
if (flag) {
String cluster = point + " " + id;
clusterSet.add(cluster);
System.out.println("簇中心" + id + ": " + cluster);
id++;
}
}
return clusterSet;
}
/**
* Converts the string features of a data point into numeric features.
*/
public static final class FeatureHandledPoint implements MapFunction<String, Point> {
public Point map(String s) throws Exception {
String[] split = s.split(",");
Point point = new Point(Arrays.asList(split));
//convert string features to numeric features
Point featurePoint = point.featureHandled();
return featurePoint;
}
}
/**
* Computes the per-feature maximum and minimum over the data points.
*/
public static final class MaxMinHandledPoint implements ReduceFunction<Point> {
public Point reduce(Point p1, Point p2) throws Exception {
//track the per-feature maximum and minimum
return p1.maxMinValue(p2);
}
}
/**
* Applies the string feature conversion and normalization to a cluster centre.
*/
public static final class HandledCluster extends RichMapFunction<String, Cluster> {
private List<Point> maxMinPoints;
@Override
public void open(Configuration parameters) throws Exception {
this.maxMinPoints = getRuntimeContext().getBroadcastVariable("maxMinPoint");
}
public Cluster map(String s) throws Exception {
String[] fields = s.split(" ");
String[] splits = fields[0].split(",");
Point centre = new Point(Arrays.asList(splits));
//convert string features to numeric features
Point featureCentre = centre.featureHandled();
//normalize
Point standardCentre = featureCentre.standardHandled(maxMinPoints.get(0));
return new Cluster(Integer.parseInt(fields[1]), standardCentre);
}
}
/**
* Normalizes a data point:
* X(norm) = (X - min) / (max - min)
*/
public static final class StandardHandledPoint extends RichMapFunction<Point, Point> {
//number of points
private IntCounter pointCount = new IntCounter();
private List<Point> maxMinPoints;
@Override
public void open(Configuration parameters) throws Exception {
getRuntimeContext().addAccumulator("pointCount", pointCount);
this.maxMinPoints = getRuntimeContext().getBroadcastVariable("maxMinPoint");
}
public Point map(Point point) throws Exception {
//normalize each point
Point standardPoint = point.standardHandled(maxMinPoints.get(0));
pointCount.add(1);
return standardPoint;
}
}
/**
* Finds the nearest cluster centre for a data point.
*/
public static final class SelectNearestCluster extends RichMapFunction<Point, Tuple2<Integer, Point>> {
private Collection<Cluster> clusters;
@Override
public void open(Configuration parameters) throws Exception {
this.clusters = getRuntimeContext().getBroadcastVariable("clusters");
}
@Override
public Tuple2<Integer, Point> map(Point point) throws Exception {
double minDistance = Double.MAX_VALUE;
int closestClusterId = -1;
for (Cluster cluster : clusters) {
double distance = point.euclideanDistance(cluster.centre);
if (distance < minDistance) {
minDistance = distance;
closestClusterId = cluster.id;
}
}
return new Tuple2<Integer, Point>(closestClusterId, point);
}
}
/**
* Counts the points in each cluster and sums their coordinates.
*/
public static final class ClusterAccumulator implements ReduceFunction<Tuple2<Integer, Point>> {
public Tuple2<Integer, Point> reduce(Tuple2<Integer, Point> val1, Tuple2<Integer, Point> val2) {
// sum the point coordinates and accumulate the point count within the cluster
return new Tuple2<Integer, Point>(val1.f0, val1.f1.add(val2.f1));
}
}
/**
* Computes the new cluster centre from the number of points in the cluster and the sum of their coordinates.
*/
public static final class ClusterAverager implements MapFunction<Tuple2<Integer, Point>, Cluster> {
public Cluster map(Tuple2<Integer, Point> value) {
// id and coordinates of the new cluster centre
return new Cluster(value.f0, value.f1.div(value.f1.number));
}
}
/**
* Filters the cluster centres by the distance each centre moved during the iteration.
*/
public static final class TerminationCriterion extends RichFilterFunction<Tuple2<Cluster, Cluster>> {
public boolean filter(Tuple2<Cluster, Cluster> value) throws Exception {
ParameterTool params = (ParameterTool) getRuntimeContext().getExecutionConfig().getGlobalJobParameters();
double disDiff = params.has("disDiff") ? Double.parseDouble(params.get("disDiff")) : KMeansConstant.DISDIFF;
double moveDistance = value.f0.centre.euclideanDistance(value.f1.centre);
System.out.println("簇中心" + value.f0.id + "移动距离: " + moveDistance);
return moveDistance > disDiff;
}
}
/**
* Maps each Tuple2<Integer, Point> to a Tuple3<Integer, String, Long> of (cluster id, label, count).
*/
public static final class CountClusterLabels implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, String, Long>> {
public Tuple3<Integer, String, Long> map(Tuple2<Integer, Point> integerPointTuple2) throws Exception {
//count the points per cluster and per label
return new Tuple3<Integer, String, Long>(integerPointTuple2.f0, integerPointTuple2.f1.sourceFields.get(41), 1L);
}
}
/**
* Maps each Tuple2<Integer, Point> to a Tuple2<Integer, Long> of (cluster id, count) and feeds the per-cluster accumulators.
*/
public static final class CountCluster extends RichMapFunction<Tuple2<Integer, Point>, Tuple2<Integer, Long>> {
//normal points in cluster 1
private IntCounter cluster1Normal = new IntCounter();
//abnormal points in cluster 1
private IntCounter cluster1Abnormal = new IntCounter();
//normal points in cluster 2
private IntCounter cluster2Normal = new IntCounter();
//abnormal points in cluster 2
private IntCounter cluster2Abnormal = new IntCounter();
@Override
public void open(Configuration parameters) throws Exception {
getRuntimeContext().addAccumulator("cluster1Normal", cluster1Normal);
getRuntimeContext().addAccumulator("cluster1Abnormal", cluster1Abnormal);
getRuntimeContext().addAccumulator("cluster2Normal", cluster2Normal);
getRuntimeContext().addAccumulator("cluster2Abnormal", cluster2Abnormal);
}
public Tuple2<Integer, Long> map(Tuple2<Integer, Point> t2) throws Exception {
if (t2.f0 == 1) {
if ("normal.".equals(t2.f1.sourceFields.get(41))) {
cluster1Normal.add(1);
} else {
cluster1Abnormal.add(1);
}
} else if (t2.f0 == 2) {
if ("normal.".equals(t2.f1.sourceFields.get(41))) {
cluster2Normal.add(1);
} else {
cluster2Abnormal.add(1);
}
}
//emit a count of one per point for the per-cluster totals
return new Tuple2<Integer, Long>(t2.f0, 1L);
}
}
}
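A note on getPurity above: with two clusters and two classes it assigns the cluster that holds more attack points to the "abnormal" class, the other cluster to "normal", and reports the fraction of points that land in their matching cluster. The textbook purity measure instead lets every cluster take its own majority label:

\[
\mathrm{purity} = \frac{1}{N} \sum_{k} \max_{j} \left| C_k \cap T_j \right|
\]

where N is the total number of points, C_k the set of points in cluster k, and T_j the set of points with ground-truth label j. The two measures differ when both clusters are dominated by the same label; in the sample run below, textbook purity would be (85885 + 164551) / 311029 ≈ 0.805, while the matched-assignment value printed is 0.720.

A sample run produced the following output: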
Cluster centre 1: 0,tcp,http,SF,236,314,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,4,4,0.00,0.00,0.00,0.00,1.00,0.00,0.00,255,255,1.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,normal. 1
Cluster centre 2: 0,icmp,ecr_i,SF,1032,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,120,120,0.00,0.00,0.00,0.00,1.00,0.00,0.00,255,255,1.00,0.00,1.00,0.00,0.00,0.00,0.00,0.00,smurf. 2
Cluster centre 2 moved: 0.826594231159133
Cluster centre 1 moved: 1.2939328884752181
Cluster centre 1 moved: 0.04597443790175165
Cluster centre 2 moved: 0.06020607832667806
Cluster centre 2 moved: 0.0022798932582670174
Cluster centre 1 moved: 0.0025190908117799595
Cluster centre 1 moved: 2.0983487049015966E-4
Cluster centre 2 moved: 1.841180332519695E-4
Cluster centre 2 moved: 1.5844853826986143E-5
Cluster centre 1 moved: 1.808769339554904E-5
Cluster centre 1 moved: 2.1330751274520463E-16
Cluster centre 2 moved: 1.1775693753296206E-16
(1,145222)
(2,165807)
Total number of data points: 311029
Normal points in cluster 1: 59337
Abnormal points in cluster 1: 85885
Normal points in cluster 2: 1256
Abnormal points in cluster 2: 164551
purity: 0.7198299836992692