While working through the book 《Hadoop大数据开发基础》 (Hadoop Big Data Development Fundamentals), I organized the project case study it presents. Writing it up helps me go over the workflow one more time, and I hope it can also be of some help to anyone who needs it. If anything is poorly written, please point it out so we can improve together.
The KNN algorithm, short for K-Nearest Neighbors, is a classification algorithm in which K is the number of data samples closest to the sample under consideration.
Suppose the samples in a sample space have already been divided into several classes. Given a piece of data to classify, we find the K samples nearest to it and use them to decide which class the data belongs to.
Put simply, the K nearest points vote on the class of the data to be classified. For example, with K = 5, if three of the five nearest neighbors belong to class A and two to class B, the data is assigned to class A.
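Before moving to the MapReduce implementation, here is a purely illustrative single-machine sketch of the idea, using a made-up two-dimensional dataset:

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

public class KnnToy {
    public static void main(String[] args) {
        // Toy training set: two-dimensional points with known labels (made up for illustration)
        double[][] train = {{1, 1}, {1, 2}, {2, 1}, {8, 8}, {8, 9}, {9, 8}};
        String[] labels = {"A", "A", "A", "B", "B", "B"};
        double[] query = {2, 2};   // the point to classify
        int k = 3;

        // Sort the training indices by Euclidean distance to the query point
        Integer[] idx = {0, 1, 2, 3, 4, 5};
        Arrays.sort(idx, Comparator.comparingDouble(i -> distance(train[i], query)));

        // Majority vote among the K nearest neighbors
        Map<String, Integer> votes = new HashMap<>();
        for (int i = 0; i < k; i++) {
            votes.merge(labels[idx[i]], 1, Integer::sum);
        }
        String best = null;
        for (Map.Entry<String, Integer> e : votes.entrySet()) {
            if (best == null || e.getValue() > votes.get(best)) {
                best = e.getKey();
            }
        }
        System.out.println("Predicted label: " + best); // prints "A"
    }

    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return Math.sqrt(sum);
    }
}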
Data required for the project:
(The Baidu Cloud link kept getting taken down, so this time I'm sharing it via Weiyun.)
Link: https://share.weiyun.com/1uVK7mpg Password: wc677y
Part of the users' movie rating data, ratings.dat, is shown in the figure. It contains four fields: UserID, MovieID, Rating and Timestamp. UserID ranges from 1 to 6040, MovieID ranges from 1 to 3952, and Rating is on a 5-point scale, with 5 the highest score and 1 the lowest.
Part of the user data with known gender, users.dat, is shown in the figure. It contains five fields: UserID, Gender, Age, Occupation and Zip-code. The Occupation field encodes 21 different occupation types, and the Age field records an age bracket rather than the user's actual age; for example, 1 means under 18. See the dataset's README for the details.
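The raw records are not reproduced here. The field descriptions match the public MovieLens 1M dataset, whose files separate fields with ::; under that assumption, typical lines look like:

1::1193::5::978300760    (ratings.dat: UserID::MovieID::Rating::Timestamp)
1::F::1::10::48067       (users.dat: UserID::Gender::Age::Occupation::Zip-code)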
(1) Join ratings.dat and users.dat on the UserID field. The result is a dataset containing UserID, Gender, Age, Occupation, Zip-code and MovieID.
Just download the ratings_users.jar package from the Weiyun link above and upload it to the /opt directory on Linux. Create a /movie directory on HDFS, upload ratings.dat and users.dat to /movie, and save the program's output under /movie/ratings_users.
The command is as follows:
hadoop jar /opt/ratings_users.jar demo.RatingsAndUsers /movie/users.dat /movie/ratings.dat /movie/ratings_users
After running, we get the joined result.
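The source inside ratings_users.jar is not listed here. Since the output file is named part-m-00000, the join runs as a map-only (replicated) job; below is a minimal sketch of how such a mapper could look, assuming users.dat is small enough to cache in memory and that the files use the MovieLens :: separator (the class name and the USERSPATH configuration key are illustrative, not the book's actual code):

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

public class RatingsUsersJoinMapper extends Mapper<LongWritable, Text, Text, NullWritable> {
    // UserID -> "Gender,Age,Occupation,Zip-code"
    private HashMap<String, String> users = new HashMap<String, String>();

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        // Cache users.dat (its path passed through the configuration) in memory
        FileSystem fs = FileSystem.get(context.getConfiguration());
        Path usersPath = new Path(context.getConfiguration().get("USERSPATH"));
        BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(usersPath)));
        String line;
        while ((line = br.readLine()) != null) {
            // users.dat: UserID::Gender::Age::Occupation::Zip-code (assumed separator)
            String[] f = line.split("::");
            users.put(f[0], f[1] + "," + f[2] + "," + f[3] + "," + f[4]);
        }
        br.close();
    }

    @Override
    protected void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {
        // ratings.dat: UserID::MovieID::Rating::Timestamp (assumed separator)
        String[] f = value.toString().split("::");
        String user = users.get(f[0]);
        if (user != null) {
            // Output: UserID,Gender,Age,Occupation,Zip-code,MovieID
            context.write(new Text(f[0] + "," + user + "," + f[1]), NullWritable.get());
        }
    }
}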
(2) Similarly, join movies.dat with the data in /movie/ratings_users/part-m-00000 on MovieID. The result is a dataset containing UserID, Gender, Age, Occupation, Zip-code, MovieID and Genres.
Then download the users_movies.jar package from the Weiyun link and place it in the /opt directory on Linux, upload movies.dat to the /movie directory on HDFS, and save the output under /movie/users_movies.
The command is as follows:
hadoop jar /opt/users_movies.jar demo.UserAndMovies /movie/movies.dat /movie/ratings_users/part-m-00000 /movie/users_movies
The result is as follows:
(3) Count the movie genres each user has watched. Also convert the Gender field: mark female (F) as 1 and male (M) as 0.
The Map-side and Reduce-side processing flow for this step is shown below:
The Mapper and Reducer classes that count the genres watched by each user:
public class MoviesGenresMapper extends Mapper<LongWritable, Text, UserAndGender, Text> {
    private UserAndGender user_gender = new UserAndGender();
    private String splitter = "";
    private Text genres = new Text();

    @Override
    protected void setup(Mapper<LongWritable, Text, UserAndGender, Text>.Context context)
            throws IOException, InterruptedException {
        // Read the field separator from the job configuration
        splitter = context.getConfiguration().get("SPLITTER");
    }

    @Override
    protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, UserAndGender, Text>.Context context)
            throws IOException, InterruptedException {
        // Input fields: UserID, Gender, Age, Occupation, Zip-code, MovieID, Genres
        String[] val = value.toString().split(splitter);
        user_gender.setUserID(val[0]);
        if (val[1].equals("M")) {
            // Mark gender M as 0
            user_gender.setGender(0);
        } else {
            // Mark gender F as 1
            user_gender.setGender(1);
        }
        user_gender.setAge(Integer.parseInt(val[2]));
        user_gender.setOccupation(val[3]);
        user_gender.setZip_code(val[4]);
        genres.set(val[6]);
        context.write(user_gender, genres);
    }
}
public class MoviesGenresReducer extends Reducer<UserAndGender, Text, Text, NullWritable> {
    @Override
    protected void reduce(UserAndGender key, Iterable<Text> value,
            Reducer<UserAndGender, Text, Text, NullWritable>.Context context) throws IOException, InterruptedException {
        // Initialize a map whose keys are the 18 movie genres, each with a count of 0.
        // A LinkedHashMap keeps the counts in a fixed, predictable column order in the output.
        LinkedHashMap<String, Integer> genresCounts = new LinkedHashMap<String, Integer>();
        String[] genreslist = {"Action", "Adventure", "Animation", "Children's", "Comedy", "Crime",
                "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical", "Mystery",
                "Romance", "Sci-Fi", "Thriller", "War", "Western"};
        for (String genre : genreslist) {
            genresCounts.put(genre, 0);
        }
        // Iterate over the list of values (one genre string per movie the user rated)
        for (Text val : value) {
            // Split each element on "|", e.g. "Action|Thriller"
            String[] genres = val.toString().split("\\|");
            for (int i = 0; i < genres.length; i++) {
                // If the map contains this genre as a key, increment its count
                if (genresCounts.containsKey(genres[i])) {
                    genresCounts.put(genres[i], genresCounts.get(genres[i]) + 1);
                }
            }
        }
        // Join all counts in the map into a comma-separated string
        String result = "";
        for (Map.Entry<String, Integer> kv : genresCounts.entrySet()) {
            if (result.length() == 0) {
                result = kv.getValue().toString();
            } else {
                result = result + "," + kv.getValue();
            }
        }
        context.write(new Text(key.toString() + "," + result), NullWritable.get());
    }
}
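The UserAndGender type used above as the map output key is not listed in this excerpt. Here is a minimal sketch of what it might look like, assuming it is a WritableComparable that compares on UserID and whose toString() joins the five fields with commas (the format the Reducer's output relies on):

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.WritableComparable;

public class UserAndGender implements WritableComparable<UserAndGender> {
    private String userID = "";
    private int gender;
    private int age;
    private String occupation = "";
    private String zip_code = "";

    // Setters used by the Mapper
    public void setUserID(String userID) { this.userID = userID; }
    public void setGender(int gender) { this.gender = gender; }
    public void setAge(int age) { this.age = age; }
    public void setOccupation(String occupation) { this.occupation = occupation; }
    public void setZip_code(String zip_code) { this.zip_code = zip_code; }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(userID);
        out.writeInt(gender);
        out.writeInt(age);
        out.writeUTF(occupation);
        out.writeUTF(zip_code);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        userID = in.readUTF();
        gender = in.readInt();
        age = in.readInt();
        occupation = in.readUTF();
        zip_code = in.readUTF();
    }

    // Group by user: comparing on UserID is enough to bring one user's
    // movie genres together in a single reduce call
    @Override
    public int compareTo(UserAndGender o) {
        return this.userID.compareTo(o.userID);
    }

    @Override
    public int hashCode() {
        // Keep partitioning consistent with compareTo
        return userID.hashCode();
    }

    @Override
    public String toString() {
        return userID + "," + gender + "," + age + "," + occupation + "," + zip_code;
    }
}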
After processing, we get the following result:
Missing and abnormal values are handled as shown in the figure below:
The code for handling missing and abnormal values:
public class DataProcessingMapper extends Mapper<LongWritable, Text, Text, NullWritable> {
    private String splitter = "";

    // Counters reporting how many missing and abnormal values were replaced
    enum DataProcessingCounter {
        NullData,
        AbnormalData
    }

    @Override
    protected void setup(Mapper<LongWritable, Text, Text, NullWritable>.Context context)
            throws IOException, InterruptedException {
        splitter = context.getConfiguration().get("SPLITTER");
    }

    @Override
    protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, NullWritable>.Context context)
            throws IOException, InterruptedException {
        String[] val = value.toString().split(splitter);
        // Fields 0-4 are UserID, Gender, Age, Occupation and Zip-code; the genre counts start at index 5
        for (int i = 5; i < val.length; i++) {
            // If the field is missing (empty, null or NaN), replace it with 0
            if (val[i].isEmpty() || val[i].equalsIgnoreCase("null") || val[i].equalsIgnoreCase("nan")) {
                context.getCounter(DataProcessingCounter.NullData).increment(1);
                val[i] = "0";
            }
            // If the field is abnormal (a negative count), replace it with 0
            if (Integer.parseInt(val[i]) < 0) {
                context.getCounter(DataProcessingCounter.AbnormalData).increment(1);
                val[i] = "0";
            }
        }
        // Rejoin the string array val into a single line
        String result = "";
        for (int i = 0; i < val.length; i++) {
            if (i == 0) {
                result = val[i];
            } else {
                result = result + splitter + val[i];
            }
        }
        context.write(new Text(result), NullWritable.get());
    }
}
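No driver is shown for this preprocessing job, but the next step reads its output from /movie/processing_out/part-m-00000, which implies a map-only job. Here is a minimal driver sketch under that assumption (the input path and the SPLITTER value are illustrative):

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class DataProcessing {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        conf.set("fs.defaultFS", "hdfs://master:8020");
        conf.set("SPLITTER", ",");  // the separator the Mapper reads in setup()
        Job job = Job.getInstance(conf, "data_processing");
        job.setJarByClass(DataProcessing.class);
        job.setMapperClass(DataProcessingMapper.class);
        job.setNumReduceTasks(0);   // map-only job, hence part-m-* output files
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(NullWritable.class);
        FileInputFormat.addInputPath(job, new Path("/movie/genres_out/part-r-00000")); // illustrative input path
        FileOutputFormat.setOutputPath(job, new Path("/movie/processing_out"));
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}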
Generally speaking, a classification algorithm involves three stages:
(1) Build the classifier by inductively analyzing the training set.
(2) Use the validation set to choose the best model parameters.
(3) Evaluate the classifier's accuracy on a test set with known classes.
Before building the movie-user classifier, this project randomly splits the preprocessed data into training, test and validation sets in an 8:1:1 ratio.
The method that reads the data from HDFS and counts the records:
/**
 * Read the raw data and count the number of records
 * @param fs the HDFS file system handle
 * @param path path of the file to count
 * @return the number of lines in the file
 * @throws Exception
 */
public static int getSize(FileSystem fs, Path path) throws Exception {
    int count = 0;
    FSDataInputStream is = fs.open(path);
    BufferedReader br = new BufferedReader(new InputStreamReader(is));
    String line = "";
    while ((line = br.readLine()) != null) {
        count++;
    }
    br.close();
    is.close();
    return count;
}
/**
 * Randomly pick the indices of 80% of the raw records
 * @param count total number of records
 * @return the set of training-record indices (0-based)
 */
public static Set<Integer> trainIndex(int count) {
    Set<Integer> train_index = new HashSet<Integer>();
    int trainSplitNum = (int) (count * 0.8);
    Random random = new Random();
    while (train_index.size() < trainSplitNum) {
        int a = random.nextInt(count);
        train_index.add(a);
    }
    return train_index;
}
/**
 * Randomly pick the indices of another 10% of the raw records,
 * excluding the training indices
 * @param count total number of records
 * @param train_index indices already assigned to the training set
 * @return the set of validation-record indices (0-based)
 */
public static Set<Integer> validateIndex(int count, Set<Integer> train_index) {
    Set<Integer> validate_index = new HashSet<Integer>();
    int validateSplitNum = count - (int) (count * 0.9);
    Random random = new Random();
    while (validate_index.size() < validateSplitNum) {
        int a = random.nextInt(count);
        if (!train_index.contains(a)) {
            validate_index.add(a);
        }
    }
    return validate_index;
}
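To make the proportions concrete: with 10,000 preprocessed records, trainIndex draws 8,000 distinct indices (10,000 × 0.8), validateIndex draws another 1,000 (10,000 − (int)(10,000 × 0.9)) while avoiding the training indices, and the remaining 1,000 lines fall through to the test set.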
Set the training set path to /movie/trainData, the validation set path to /movie/validateData, and the test set path to /movie/testData.
Write the data to HDFS:
public class SplitData {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        conf.set("fs.defaultFS", "hdfs://master:8020");
        FileSystem fs = FileSystem.get(conf);
        // Path of the preprocessed movie data
        Path moviedata = new Path("/movie/processing_out/part-m-00000");
        // Total number of records
        int datasize = getSize(fs, moviedata);
        // Indices of the records that go into the training set
        Set<Integer> train_index = trainIndex(datasize);
        System.out.println(train_index.size());
        // Indices of the records that go into the validation set
        Set<Integer> validate_index = validateIndex(datasize, train_index);
        System.out.println(validate_index.size());
        // Path for the training data
        Path train = new Path("hdfs://master:8020/movie/trainData");
        fs.delete(train, true);
        FSDataOutputStream os1 = fs.create(train);
        BufferedWriter bw1 = new BufferedWriter(new OutputStreamWriter(os1));
        // Path for the test data
        Path test = new Path("hdfs://master:8020/movie/testData");
        fs.delete(test, true);
        FSDataOutputStream os2 = fs.create(test);
        BufferedWriter bw2 = new BufferedWriter(new OutputStreamWriter(os2));
        // Path for the validation data
        Path validate = new Path("hdfs://master:8020/movie/validateData");
        fs.delete(validate, true);
        FSDataOutputStream os3 = fs.create(validate);
        BufferedWriter bw3 = new BufferedWriter(new OutputStreamWriter(os3));
        // Read the data and split it into training, test and validation sets on HDFS.
        // The indices produced by trainIndex/validateIndex are 0-based, so the line
        // counter starts at 0 and is incremented after each line is handled.
        FSDataInputStream is = fs.open(moviedata);
        BufferedReader br = new BufferedReader(new InputStreamReader(is));
        String line = "";
        int sum = 0;
        int trainsize = 0;
        int testsize = 0;
        int validatesize = 0;
        while ((line = br.readLine()) != null) {
            if (train_index.contains(sum)) {
                trainsize += 1;
                bw1.write(line);
                bw1.newLine();
            } else if (validate_index.contains(sum)) {
                validatesize += 1;
                bw3.write(line);
                bw3.newLine();
            } else {
                testsize += 1;
                bw2.write(line);
                bw2.newLine();
            }
            sum += 1;
        }
        bw1.close();
        os1.close();
        bw2.close();
        os2.close();
        bw3.close();
        os3.close();
        br.close();
        is.close();
        fs.close();
    }
}
Algorithm description:
1. Define a custom value type to hold the distance and the class label. KNN computes the distance between each test record and every training record of known class, finds the K training records nearest to the test record, and takes the mode of those K records' classes as the test record's class. The map stage therefore has to output, for each test record, the distance between it and the training record just read, together with that training record's class. Hadoop's built-in Text type could carry the distance and class, but to make the program more efficient it is better to define a custom value type for them.
2. In the map stage, the setup function reads in the test data. The map function reads one training record at a time, iterates over the test records, and computes the distance between the training record and each test record using the Euclidean distance. The map output key is the test record, and the output value is the distance between that test record and the training record plus the training record's class.
3. In the reduce stage, the setup function initializes the parameter K. The reduce function sorts the values sharing a key in ascending order of distance, takes the first K values, and outputs the key together with the mode of the classes among those K values.
public class DistanceAndLabel implements Writable {
    private double distance;
    private String label;

    public DistanceAndLabel() {
    }

    public DistanceAndLabel(double distance, String label) {
        this.distance = distance;
        this.label = label;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    /**
     * Read the distance first, then the label
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.distance = in.readDouble();
        this.label = in.readUTF();
    }

    /**
     * Write the distance to the output stream first,
     * then write the label
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeDouble(distance);
        out.writeUTF(label);
    }

    /**
     * Join the distance and the label into a comma-separated string
     */
    @Override
    public String toString() {
        return this.distance + "," + this.label;
    }
}
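Note that toString() renders each value as "distance,label". The reducer below depends on exactly this format: it sorts the strings by the number before the first comma and extracts the label after the last comma, so the two classes must stay in sync.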
public class MovieClassifyMapper extends Mapper<LongWritable, Text, Text, DistanceAndLabel> {
    private DistanceAndLabel distance_label = new DistanceAndLabel();
    private String splitter = "";
    ArrayList<String> testData = new ArrayList<String>();
    private String testPath = "";

    @Override
    protected void setup(Mapper<LongWritable, Text, Text, DistanceAndLabel>.Context context)
            throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        splitter = conf.get("SPLITTER");
        testPath = conf.get("TESTPATH");
        // Read the test data into the testData list
        FileSystem fs = FileSystem.get(conf);
        FSDataInputStream is = fs.open(new Path(testPath));
        BufferedReader br = new BufferedReader(new InputStreamReader(is));
        String line = "";
        while ((line = br.readLine()) != null) {
            testData.add(line);
        }
        br.close();
        is.close();
    }

    @Override
    protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, DistanceAndLabel>.Context context)
            throws IOException, InterruptedException {
        double distance = 0.0;
        String[] val = value.toString().split(splitter);
        // The feature vector is the 18 genre counts starting at index 5
        String[] singleTrainData = Arrays.copyOfRange(val, 5, val.length);
        // Column 1 (gender, 0 or 1) is the class label
        String label = val[1];
        for (String td : testData) {
            String[] test = td.split(splitter);
            String[] singleTestData = Arrays.copyOfRange(test, 5, test.length);
            distance = Distance(singleTrainData, singleTestData);
            distance_label.setDistance(distance);
            distance_label.setLabel(label);
            context.write(new Text(td), distance_label);
        }
    }

    /**
     * Compute the Euclidean distance between a training record and a test record
     * @param singleTrainData
     * @param singleTestData
     * @return
     */
    private double Distance(String[] singleTrainData, String[] singleTestData) {
        double sum = 0.0;
        for (int i = 0; i < singleTrainData.length; i++) {
            // Sum the squared differences of the corresponding features
            double diff = Double.parseDouble(singleTrainData[i]) - Double.parseDouble(singleTestData[i]);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }
}
public class MovieClassifyReducer extends Reducer<Text, DistanceAndLabel, Text, NullWritable> {
    private int k = 0;

    @Override
    protected void setup(Reducer<Text, DistanceAndLabel, Text, NullWritable>.Context context)
            throws IOException, InterruptedException {
        // Initialize K
        k = context.getConfiguration().getInt("K", 3);
    }

    @Override
    protected void reduce(Text key, Iterable<DistanceAndLabel> value,
            Reducer<Text, DistanceAndLabel, Text, NullWritable>.Context context) throws IOException, InterruptedException {
        String label = getMost(getTopK(sort(value)));
        context.write(new Text(label + "," + key), NullWritable.get());
    }

    /**
     * Find the mode (most frequent label) in the list
     * @param topK
     * @return
     */
    private String getMost(List<String> topK) {
        HashMap<String, Integer> labelTimes = new HashMap<String, Integer>();
        for (String str : topK) {
            String label = str.substring(str.lastIndexOf(",") + 1, str.length());
            if (labelTimes.containsKey(label)) {
                labelTimes.put(label, labelTimes.get(label) + 1);
            } else {
                labelTimes.put(label, 1);
            }
        }
        int maxInt = Integer.MIN_VALUE;
        String mostLabel = "";
        for (Map.Entry<String, Integer> kv : labelTimes.entrySet()) {
            if (kv.getValue() > maxInt) {
                maxInt = kv.getValue();
                mostLabel = kv.getKey();
            }
        }
        return mostLabel;
    }

    /**
     * Take the first K values of the list
     * (bounded by the list size, in case fewer than K values arrive)
     * @param sort
     * @return
     */
    private List<String> getTopK(List<String> sort) {
        return sort.subList(0, Math.min(k, sort.size()));
    }

    /**
     * Sort in ascending order of distance
     * @param value
     * @return
     */
    private List<String> sort(Iterable<DistanceAndLabel> value) {
        ArrayList<String> result = new ArrayList<String>();
        for (DistanceAndLabel val : value) {
            result.add(val.toString());
        }
        String[] tmp = new String[result.size()];
        result.toArray(tmp);
        Arrays.sort(tmp, new Comparator<String>() {
            @Override
            public int compare(String o1, String o2) {
                // Each string is "distance,label"; compare on the distance part
                double o1D = Double.parseDouble(o1.substring(0, o1.indexOf(",")));
                double o2D = Double.parseDouble(o2.substring(0, o2.indexOf(",")));
                return Double.compare(o1D, o2D);
            }
        });
        return Arrays.asList(tmp);
    }
}
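One design point worth noting: sort() copies every value for a key into memory before sorting. At MovieLens scale this is fine, but for a much larger training set a bounded structure such as a size-K heap, or a secondary sort on a composite key, would avoid holding all the distances at once.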
public class MovieClassify extends Configured implements Tool {
    @Override
    public int run(String[] args) throws Exception {
        if (args.length != 5) {
            System.err.println("Usage: demo.MovieClassify <test path> <input path> <output path> <K> <splitter>");
            System.exit(-1);
        }
        Configuration conf = getMyConfiguration();
        conf.setInt("K", Integer.parseInt(args[3]));
        conf.set("SPLITTER", args[4]);
        conf.set("TESTPATH", args[0]);
        Job job = Job.getInstance(conf, "movie_knn");
        job.setJarByClass(MovieClassify.class);             // set the main class
        job.setMapperClass(MovieClassifyMapper.class);      // set the Mapper class
        job.setReducerClass(MovieClassifyReducer.class);    // set the Reducer class
        job.setMapOutputKeyClass(Text.class);               // Mapper output key type
        job.setMapOutputValueClass(DistanceAndLabel.class); // Mapper output value type
        job.setOutputKeyClass(Text.class);                  // Reducer output key type
        job.setOutputValueClass(NullWritable.class);        // Reducer output value type
        FileInputFormat.addInputPath(job, new Path(args[1]));   // input path
        FileSystem.get(conf).delete(new Path(args[2]), true);   // delete the output path if it exists
        FileOutputFormat.setOutputPath(job, new Path(args[2])); // output path
        return job.waitForCompletion(true) ? 0 : 1;         // submit the job
    }

    public static void main(String[] args) {
        String[] myArgs = {
                "/movie/testData",
                "/movie/trainData",
                "/movie/knnout",
                "3",
                ","
        };
        try {
            ToolRunner.run(getMyConfiguration(), new MovieClassify(), myArgs);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Build the configuration for connecting to the Hadoop cluster
     * @return
     */
    public static Configuration getMyConfiguration() {
        Configuration conf = new Configuration();
        conf.setBoolean("mapreduce.app-submission.cross-platform", true);
        conf.set("fs.defaultFS", "hdfs://master:8020");     // namenode address
        conf.set("mapreduce.framework.name", "yarn");       // use the YARN framework
        String resourcenode = "master";
        conf.set("yarn.resourcemanager.address", resourcenode + ":8032");           // resourcemanager
        conf.set("yarn.resourcemanager.scheduler.address", resourcenode + ":8030"); // scheduler
        conf.set("mapreduce.jobhistory.address", resourcenode + ":10020");
        conf.set("mapreduce.job.jar", JarUtil.jar(MovieClassify.class));
        return conf;
    }
}
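The getMyConfiguration() method above points mapreduce.job.jar at a JAR that the JarUtil helper below builds on the fly from the classes compiled under the project's bin/ directory; together with the cross-platform switch, this lets the job be submitted to the cluster directly from the IDE: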
public class JarUtil {
    public static String jar(Class<?> cls) {
        String outputJar = cls.getName() + ".jar";
        String input = cls.getClassLoader().getResource("").getFile();
        input = input.substring(0, input.length() - 1);
        input = input.substring(0, input.lastIndexOf("/") + 1);
        input = input + "bin/";
        jar(input, outputJar);
        return outputJar;
    }

    private static void jar(String inputFileName, String outputFileName) {
        JarOutputStream out = null;
        try {
            out = new JarOutputStream(new FileOutputStream(outputFileName));
            File f = new File(inputFileName);
            jar(out, f, "");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                if (out != null) {
                    out.close();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    private static void jar(JarOutputStream out, File f, String base) throws Exception {
        if (f.isDirectory()) {
            File[] fl = f.listFiles();
            base = base.length() == 0 ? "" : base + "/"; // note: forward slash, as JAR entries require
            for (int i = 0; i < fl.length; i++) {
                jar(out, fl[i], base + fl[i].getName());
            }
        } else {
            out.putNextEntry(new JarEntry(base));
            FileInputStream in = new FileInputStream(f);
            byte[] buffer = new byte[1024];
            int n = in.read(buffer);
            while (n != -1) {
                out.write(buffer, 0, n);
                n = in.read(buffer);
            }
            in.close();
        }
    }
}
The formula for accuracy:
Accuracy = (number of correctly classified records) ÷ (total number of records classified)
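For example, if 450 of the 500 records in the validation set are assigned the correct gender, the accuracy is 450 ÷ 500 = 0.9.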
The evaluation approach:
The evaluation code, Mapper class:
public class ValidateMapper extends Mapper<LongWritable, Text, NullWritable, Text> {
    private String splitter = "";

    @Override
    protected void setup(Mapper<LongWritable, Text, NullWritable, Text>.Context context)
            throws IOException, InterruptedException {
        splitter = context.getConfiguration().get("SPLITTER");
    }

    @Override
    protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, NullWritable, Text>.Context context)
            throws IOException, InterruptedException {
        // Each KNN output line is: predicted label, UserID, true gender, ...
        // Emit "predicted label,true label" under a single NullWritable key,
        // so that one reduce call sees every record
        String[] val = value.toString().split(splitter);
        context.write(NullWritable.get(), new Text(val[0] + splitter + val[2]));
    }
}
The Reducer class:
public class ValidateReducer extends Reducer<NullWritable, Text, DoubleWritable, NullWritable> {
    private String splitter = "";

    @Override
    protected void setup(Reducer<NullWritable, Text, DoubleWritable, NullWritable>.Context context)
            throws IOException, InterruptedException {
        splitter = context.getConfiguration().get("SPLITTER");
    }

    @Override
    protected void reduce(NullWritable key, Iterable<Text> value,
            Reducer<NullWritable, Text, DoubleWritable, NullWritable>.Context context)
            throws IOException, InterruptedException {
        // sum counts the correctly classified records
        int sum = 0;
        // count records the total number of classified records, i.e. the test set size
        int count = 0;
        for (Text val : value) {
            count++;
            String predictLabel = val.toString().split(splitter)[0];
            String trueLabel = val.toString().split(splitter)[1];
            // Check whether the predicted class matches the true class
            if (predictLabel.equals(trueLabel)) {
                sum += 1;
            }
        }
        // Compute the accuracy
        double accuracy = (double) sum / count;
        context.write(new DoubleWritable(accuracy), NullWritable.get());
    }
}
The driver class:
public class Validate extends Configured implements Tool {
    @Override
    public int run(String[] args) throws Exception {
        if (args.length != 3) {
            System.err.println("Usage: demo01.Validate <input path> <output path> <splitter>");
            System.exit(-1);
        }
        Configuration conf = getMyConfiguration();
        conf.set("SPLITTER", args[2]);
        Job job = Job.getInstance(conf, "validate");
        job.setJarByClass(Validate.class);             // set the main class
        job.setMapperClass(ValidateMapper.class);      // set the Mapper class
        job.setReducerClass(ValidateReducer.class);    // set the Reducer class
        job.setMapOutputKeyClass(NullWritable.class);  // Mapper output key type
        job.setMapOutputValueClass(Text.class);        // Mapper output value type
        job.setOutputKeyClass(DoubleWritable.class);   // Reducer output key type
        job.setOutputValueClass(NullWritable.class);   // Reducer output value type
        FileInputFormat.addInputPath(job, new Path(args[0]));   // input path
        FileSystem.get(conf).delete(new Path(args[1]), true);   // delete the output path if it exists
        FileOutputFormat.setOutputPath(job, new Path(args[1])); // output path
        return job.waitForCompletion(true) ? 0 : 1;
    }

    public static void main(String[] args) {
        String[] myArgs = {
                "/movie/knnout/part-r-00000",
                "/movie/validateout",
                ","
        };
        try {
            ToolRunner.run(getMyConfiguration(), new Validate(), myArgs);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Build the configuration for connecting to the Hadoop cluster
     * @return
     */
    public static Configuration getMyConfiguration() {
        Configuration conf = new Configuration();
        conf.setBoolean("mapreduce.app-submission.cross-platform", true);
        conf.set("fs.defaultFS", "hdfs://master:8020");     // namenode address
        conf.set("mapreduce.framework.name", "yarn");       // use the YARN framework
        String resourcenode = "master";
        conf.set("yarn.resourcemanager.address", resourcenode + ":8032");           // resourcemanager
        conf.set("yarn.resourcemanager.scheduler.address", resourcenode + ":8030"); // scheduler
        conf.set("mapreduce.jobhistory.address", resourcenode + ":10020");
        conf.set("mapreduce.job.jar", JarUtil.jar(Validate.class));
        return conf;
    }
}
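Running Validate on the KNN output /movie/knnout/part-r-00000 writes a single accuracy value to /movie/validateout/part-r-00000; this is the file the AllJob class below reads back when searching for the best K.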
The value of K has a major impact on KNN's classification results.
The figure below shows the accuracy for K = 3, 4, 5, 6 and 7; among these, K = 3 gives the highest accuracy.
However, even though K = 3 is the most accurate of K = 3, 4, 5, 6, 7, that does not mean the model with K = 3 is the best classifier.
To select K, we can use the validation set together with an iterative procedure: for each candidate K, run the classifier on the training data, measure its accuracy on the validation set, and keep the K whose accuracy is highest.
Following this idea, we write an AllJob class to select the best K. AllJob has a single main method that loops over the candidate K values; each iteration calls the MapReduce program that classifies user gender and then the MapReduce program that evaluates the classification accuracy.
Code for selecting the best K value:
public class AllJob {
    public static void main(String[] args) throws IOException {
        Configuration conf = new Configuration();
        conf.set("fs.defaultFS", "hdfs://master:8020");
        FileSystem fs = FileSystem.get(conf);
        double maxAccuracy = 0.0;
        int bestK = 0;
        int[] k = {2, 3, 5, 9, 15, 30, 55, 70, 80, 100};
        for (int i = 0; i < k.length; i++) {
            double accuracy = 0.0;
            // Classify the validation data against the training data with the current K
            String[] classifyArgs = {
                    "/movie/validateData",
                    "/movie/trainData",
                    "/movie/knnout",
                    String.valueOf(k[i]),
                    ","
            };
            try {
                ToolRunner.run(demo.MovieClassify.getMyConfiguration(), new demo.MovieClassify(), classifyArgs);
            } catch (Exception e) {
                e.printStackTrace();
            }
            // Evaluate the accuracy of the classification result
            String[] validateArgs = {
                    "/movie/knnout/part-r-00000",
                    "/movie/validateout",
                    ","
            };
            try {
                ToolRunner.run(demo01.Validate.getMyConfiguration(), new demo01.Validate(), validateArgs);
            } catch (Exception e) {
                e.printStackTrace();
            }
            // Read back the accuracy written by the Validate job
            FSDataInputStream is = fs.open(new Path("/movie/validateout/part-r-00000"));
            BufferedReader br = new BufferedReader(new InputStreamReader(is));
            String line = "";
            while ((line = br.readLine()) != null) {
                accuracy = Double.parseDouble(line);
            }
            br.close();
            is.close();
            // Keep the K with the highest accuracy so far
            if (accuracy > maxAccuracy) {
                maxAccuracy = accuracy;
                bestK = k[i];
            }
            System.out.println("K=" + k[i] + "\t" + "accuracy=" + accuracy);
        }
        System.out.println("Best K: " + bestK + "\t" + "accuracy at best K: " + maxAccuracy);
    }
}
Advantages:
Disadvantages:
Remedies: (1) weight the attributes; (2) remove irrelevant attributes.