在Kaggle手写数字数据集上使用Spark MLlib的朴素贝叶斯模型进行手写数字识别

昨天我在Kaggle上下载了一份用于手写数字识别的数据集,想通过最近学习到的一些方法来训练一个模型进行手写数字识别。这些数据集是从28×28像素大小的手写数字灰度图像中得来,其中训练数据第一个元素是具体的手写数字,剩下的784个元素是手写数字灰度图像每个像素的灰度值,范围为[0,255],测试数据则没有训练数据中的第一个元素,只包含784个灰度值。现在我打算使用Spark MLlib中提供的朴素贝叶斯算法来训练模型。

首先来设定Spark上下文的一些参数:

val conf = new SparkConf()
    .setAppName("DigitRecgonizer")
    .setMaster("local[*]")
    .set("spark.driver.memory", "10G")
val sc = new SparkContext(conf)

这样Spark上下文已经创建完毕了,那么现在来读取训练数据吧,在这里我把原本的训练数据的header去掉了,只保留了数据部,训练数据是以csv格式保存的:

val rawData = sc.textFile("file://path/train-noheader.csv")

由于数据是csv格式,所以接下来用“,”将每行数据转换成数组:

val records = rawData.map(line => line.split(","))

下面需要将这些数据处理成朴素贝叶斯能够接受的数据类型LabeledPoint ,此类型接收两个参数,第一个参数是label(标签,在这里就是具体的手写数字),第二个参数是features (特征向量,在这里是784个灰度值):

    val records = rawData.map(line => line.split(","))
    val data = records.map{ r =>
      val label = r(0).toInt
      val features = r.slice(1, r.size).map(p => p.toDouble)
      LabeledPoint(label, Vectors.dense(features))
    }

现在已经把数据都准备好了,可以开始训练模型了,在MLlib中,只需要简单地调用train 方法就能完成模型的训练:

val nbModel = NaiveBayes.train(data)

现在已经训练出了一个模型,我们看看它在训练数据集上的准确率如何,在这里我将训练数据集的特征传给模型进行训练,将得到的结果与真实的结果进行对比,然后统计出正确的条数,以此来评估模型的准确率,这应该也算是一种交叉验证吧:

    val nbTotalCorrect = data.map { point =>
      if (nbModel.predict(point.features) == point.label) 1 else 0
    }.sum
    val numData = data.count()
    val nbAccuracy = nbTotalCorrect / numData

运行完这段代码,我得到的准确率是0.8261190476190476

下面开始对测试数据进行识别了,首先读入测试数据:

val unlabeledData = sc.textFile("file://path/test-noheader.csv")

再用与之前同样的方式进行预处理:

val unlabeledRecords = unlabeledData.map(line => line.split(","))
val features = unlabeledRecords.map{ r =>
  val f = r.map(p => p.toDouble)
  Vectors.dense(f)
}

注意,测试数据中没有标签,所以将它所有数值都作为特征features

现在开始对测试数据进行识别,并把结果保存为文件:

    val predictions = nbModel.predict(features).map(p => p.toInt)
    predictions.repartition(1).saveAsTextFile("file://path/digitRec.txt")

到这里所有工作都完成了,之后我把计算出来的结果上传到Kaggle上,发现准确率在0.83左右,与我之前在训练数据集上得到的评估结果相近。

今天就到这里吧,以后可能还会寻找其他的方式来训练模型,看看效果如何。

你可能感兴趣的:(机器学习,个人项目,神经网络与机器学习笔记)