Chinese text classification with Spark's NaiveBayes

Chinese word segmentation is done with the ANSJ toolkit, which requires two jar packages: ansj_seg and nlp-lang. Download locations for the ANSJ segmentation jars:

ansj_seg jar download:

https://oss.sonatype.org/content/repositories/releases/org/ansj/ansj_seg/

nlp-lang jar download:

https://oss.sonatype.org/content/repositories/releases/org/nlpcn/nlp-lang/
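
If the project is built with a dependency manager such as sbt instead of adding the jars by hand, the two libraries can be declared roughly as follows (the version numbers are assumptions; use whichever versions you download from the links above):

// build.sbt -- the version numbers are illustrative, not prescriptive
libraryDependencies ++= Seq(
  "org.ansj"  % "ansj_seg" % "5.1.6",
  "org.nlpcn" % "nlp-lang" % "1.7.7"
)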

Chinese word segmentation:

import java.io.InputStream
import java.util

import org.ansj.domain.Result
import org.ansj.recognition.impl.StopRecognition
import org.ansj.splitWord.analysis.ToAnalysis
import org.ansj.util.MyStaticValue
import org.apache.spark.{SparkConf, SparkContext}
import org.nlpcn.commons.lang.tire.domain.{Forest, Value}
import org.nlpcn.commons.lang.tire.library.Library
import org.nlpcn.commons.lang.util.IOUtil

class ChineseSegment extends Serializable {

  @transient private val sparkConf: SparkConf = new SparkConf().setAppName("chinese segment")
  @transient private val sparkContext: SparkContext = SparkContext.getOrCreate(sparkConf)

  // load the stop-word dictionary from HDFS and build the stop-word filter
  private val stopLibRecog = new StopLibraryRecognition
  private val stopLib: util.ArrayList[String] = stopLibRecog.stopLibraryFromHDFS(sparkContext)
  private val selfStopRecognition: StopRecognition = stopLibRecog.stopRecognitionFilter(stopLib)

  // load the user-defined dictionary from HDFS and build the dictionary forest used by ANSJ
  private val dicUserLibrary = new DicUserLibrary
  @transient private val aListDicLibrary: util.ArrayList[Value] = dicUserLibrary.getUserLibraryList(sparkContext)
  @transient private val dirLibraryForest: Forest = Library.makeForest(aListDicLibrary)

  /** Chinese word segmentation with the stop-word filter applied */
  def cNSeg(comment: String): String = {

    val result: Result = ToAnalysis.parse(comment, dirLibraryForest).recognition(selfStopRecognition)
    result.toStringWithOutNature(" ")
  }


}
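
A minimal usage sketch of the class above (assuming the stop-word and user dictionaries are reachable at the HDFS paths hard-coded in the helper classes and that a Spark master is configured for the application; the sample sentence and output are illustrative only):

object ChineseSegmentDemo {
  def main(args: Array[String]): Unit = {
    val segment = new ChineseSegment
    // prints the space-separated tokens that survive the stop-word filters,
    // e.g. something like "电影 剧情 演员 出色" for the sample sentence below
    println(segment.cNSeg("这部电影的剧情和演员都非常出色"))
  }
}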


/** Stop-word dictionary recognition.
  * Format: word  stop-word type [may be empty], columns separated by a Tab character.
  * For example:
  * #
  * v nature
  * .*了 regex
  *
  * */

class StopLibraryRecognition extends Serializable {

  def stopRecognitionFilter(arrayList: util.ArrayList[String]): StopRecognition = {

    MyStaticValue.isQuantifierRecognition = true // merge numerals with their measure words (quantifier recognition)

    val stopRecognition = new StopRecognition

    // filter out prepositions (p), interjections (e), conjunctions (c), pronouns (r), particles (u), non-word strings (x) and onomatopoeia (o)
    stopRecognition.insertStopNatures("p", "e", "c", "r", "u", "x", "o")

    stopRecognition.insertStopNatures("w")  // remove punctuation

    // remove tokens that start with a Chinese numeral ("一" to "十") followed by at most two more characters; longer tokens are kept
    stopRecognition.insertStopRegexes("^一.{0,2}","^二.{0,2}","^三.{0,2}","^四.{0,2}","^五.{0,2}",
      "^六.{0,2}","^七.{0,2}","^八.{0,2}","^九.{0,2}","^十.{0,2}")

    stopRecognition.insertStopNatures("null") // remove tokens with a null nature

    stopRecognition.insertStopRegexes(".{0,1}")  // remove empty and single-character tokens

    stopRecognition.insertStopRegexes("^[a-zA-Z]{1,}")  // remove tokens that consist only of English letters

    stopRecognition.insertStopWords(arrayList)  // add the stop words loaded from the dictionary

    stopRecognition.insertStopRegexes("^[0-9]+") // remove tokens that consist only of digits

    stopRecognition.insertStopRegexes("[^a-zA-Z0-9\\u4e00-\\u9fa5]+")  // remove tokens containing anything other than Chinese characters, English letters or digits

    stopRecognition
  }


  def stopLibraryFromHDFS(sparkContext: SparkContext): util.ArrayList[String] = {
    /** Method 2: read the stop.dic data from HDFS.
      * When running on a cluster, the stop-word data must be placed on HDFS so that every node can access it. */
    val stopLib: Array[String] = sparkContext.textFile("hdfs://zysdmaster000:8020/data/library/stop.dic").collect()
    val arrayList: util.ArrayList[String] = new util.ArrayList[String]()
    stopLib.foreach(arrayList.add)

    arrayList

  }


  def getStopLibraryFromLocal: StopRecognition = {

    /** Method 1: read the stop.dic data from the classpath; this approach does not work when running on a cluster.
      * When running locally, stop.dic can be placed under the src directory (marked as a resources root);
      * both "/library/stop.dic" and "/stop.dic" will resolve to the stop.dic data. */
    val stopRecognition: StopRecognition = new StopRecognition

    stopRecognition.insertStopNatures("w")

    val stream: InputStream = StopLibraryRecognition.this.getClass.getResourceAsStream("/stop.dic")
    val strings: util.List[String] = IOUtil.readFile2List(stream, IOUtil.UTF8)
    stopRecognition.insertStopWords(strings)

    stopRecognition
  }
}
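
The stop.dic file read by stopLibraryFromHDFS is simply a list of stop words, one per line, which are passed to insertStopWords; a hypothetical example (the entries are made up for illustration):

的
了
就是
非常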


/** User-defined dictionary.
  * Format: word  nature  frequency
  * The word, nature and frequency columns are separated by Tab characters.
  *
  * */
class DicUserLibrary extends Serializable {

  def getUserLibraryList(sparkContext: SparkContext): util.ArrayList[Value] = {
    /** Method 2: read the userLibrary.dic data from HDFS.
      * When running on a cluster, the user-library data must be placed on HDFS so that every node can access it. */
    val va: Array[String] = sparkContext.textFile("hdfs://zysdmaster000:8020/data/library/userLibrary.dic").collect()
    val arrayList: util.ArrayList[Value] = new util.ArrayList[Value]()
    va.foreach(line => arrayList.add(new Value(line)))
    arrayList
  }

  def getUserLibraryForest: Forest = {

    /** Method 1: read the userLibrary.dic data from the classpath; this approach does not work when running on a cluster.
      * When running locally, userLibrary.dic can be placed under the src directory (marked as a resources root);
      * both "/library/userLibrary.dic" and "/userLibrary.dic" will resolve to the userLibrary.dic data. */
    val stream: InputStream = DicUserLibrary.this.getClass.getResourceAsStream("/library/userLibrary.dic")
    println(stream) // debug output: confirms whether the resource was found on the classpath
    /** userLibrary.dic is placed under the src (resources) directory */
    val forestLibrary: Forest = Library.makeForest(DicUserLibrary.this.getClass.getResourceAsStream("/userLibrary.dic"))
    forestLibrary
  }
}
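
A hypothetical userLibrary.dic in the Tab-separated word / nature / frequency format described above (the entries and frequencies are made up for illustration):

复仇者联盟	nz	1000
烂片	n	1000
豆瓣评分	n	1000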

 

TF-IDF computation and classification with Spark's NaiveBayes:

import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, Tokenizer}
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}

object DoubanSentimentAnalysis {

  def main(args: Array[String]): Unit = {
    run
  }

  def run: Unit = {
    val doubanSentimentAnalysis = new DoubanSentimentAnalysis
    val table_name_label: Array[String] = Array("short", "long")
    table_name_label.foreach(doubanSentimentAnalysis.category2Mysql)
    println(".................. douban sentiment analysis has finished ..................")
  }

}

class DoubanSentimentAnalysis extends Serializable {

  @transient private val sparkConf: SparkConf = new SparkConf().setAppName("douban sentiment analysis")
  @transient private val sparkContext: SparkContext = new SparkContext(sparkConf)
  private val sqlContext: SQLContext = new SQLContext(sparkContext)

  import sqlContext.implicits._

  case class RowDataRecord(cid: String, text: String)

  def referenceModel(label: String): RDD[(Integer, Int)] = {

    val url = "jdbc:mysql://mysql_IP/databaseName?useUnicode=true&characterEncoding=UTF-8"

    // set the database username and password on the connection properties
    val properties: Properties = new Properties
    properties.setProperty("user", "myDatabaseName")
    properties.setProperty("password", "myDatabasePassword")

    val table_name = label + "_comment_douban_copy"

    // read the MySQL table over JDBC using the url, the table name and the connection properties
    val comment_douban: DataFrame = sqlContext.read.jdbc(url, table_name, properties)

    val dataFrame: DataFrame = comment_douban.select("cid", "comment").repartition(200)

    val segment = new ChineseSegment

    val segmentDFrame = dataFrame.map { f =>
      val str = segment.cNSeg(f.getString(1))
      (f(0).toString, str)
    }.toDF("cid", "segment")

    // filter out rows whose segmentation result is empty, wrap each row in the case class RowDataRecord, and convert to a DataFrame
    val partsRDD = segmentDFrame.filter("segment != ''").select("cid", "segment").map { h =>
      RowDataRecord(h(0).toString, h(1).toString)
    }.toDF()

    // split the space-separated segmentation result into an array of words
    val tokenizer: Tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val wordData: DataFrame = tokenizer.transform(partsRDD)

    // compute term frequencies (TF) over the word arrays
    val hashingTF: HashingTF = new HashingTF().setInputCol("words").setOutputCol("tfFeatures")
    val tfDFrame: DataFrame = hashingTF.transform(wordData)

    // compute IDF from the TF values to obtain TF-IDF features
    val idf: IDF = new IDF().setInputCol("tfFeatures").setOutputCol("tfidfFeatures")
    val idfModel: IDFModel = idf.fit(tfDFrame)
    val tfidfDFrame: DataFrame = idfModel.transform(tfDFrame)


    // each comment id paired with its TF-IDF vector
    val cidDFIDF: DataFrame = tfidfDFrame.select($"cid", $"tfidfFeatures")

    // load the previously trained NaiveBayes model from HDFS
    val naiveBayesModel: NaiveBayesModel = NaiveBayesModel.load(sparkContext, "hdfs://hdfsMasterIp:/data/naive_bayes_model")
    // classify each comment: vow(0) is the comment id, vow(1) is its TF-IDF vector
    val cidCategory = cidDFIDF.map { vow =>
      (Integer.valueOf(vow(0).toString), naiveBayesModel.predict(vow.getAs[Vector](1).toSparse).toInt)
    }

    cidCategory
  }

  def category2Mysql(label: String): Unit = {

    val tidCategory: RDD[(Integer, Int)] = referenceModel(label)

    val mysqlURL = "jdbc:mysql://mysql_IP/database?useUnicode=true&rewriteBatchedStatements=true&characterEncoding=UTF-8"
    val mysql = "INSERT INTO "+label+"_comment_douban_copy(cid,attitude) VALUES(?,?) ON DUPLICATE KEY UPDATE cid = values(cid),attitude = values(attitude)"

    tidCategory.foreachPartition(it => {

      val conn = DriverManager.getConnection(mysqlURL, "databaseName", "databasePassword")
      val pstat = conn.prepareStatement(mysql)
      for (obj <- it) {
        pstat.setInt(1, obj._1)
        pstat.setInt(2, obj._2)

        // add the insert to the current batch
        pstat.addBatch()
      }

      // execute the whole batch, then release the statement and the connection
      pstat.executeBatch()
      pstat.close()
      conn.close()

    })

  }
}
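
The code above only loads a model that was trained beforehand; the training step itself is not shown in this post. Below is a minimal sketch of how such a model could be produced and saved with spark.mllib, assuming a pre-segmented, labelled training file where each line is label<TAB>space-separated tokens (the file path, its layout and the feature dimension are assumptions):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.regression.LabeledPoint

object NaiveBayesTraining {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("naive bayes training"))

    // each line of the (hypothetical) training file: label<TAB>space-separated tokens, e.g. "1\t剧情 出色"
    val raw = sc.textFile("hdfs://hdfsMasterIp:/data/train_segmented.txt")
      .map(_.split("\t"))
      .filter(_.length == 2)

    // the feature dimension must match the HashingTF used at prediction time
    // (ml.feature.HashingTF defaults to 2^18 features)
    val hashingTF = new HashingTF(262144)
    val tf = raw.map { case Array(label, text) =>
      (label.toDouble, hashingTF.transform(text.split(" ").toSeq))
    }
    tf.cache()

    // fit IDF on the whole corpus, then build labelled TF-IDF vectors
    val idfModel = new IDF().fit(tf.map(_._2))
    val training = tf.map { case (label, tfVector) =>
      LabeledPoint(label, idfModel.transform(tfVector))
    }

    // train a multinomial NaiveBayes model with Laplace smoothing and save it to HDFS
    val model = NaiveBayes.train(training, lambda = 1.0)
    model.save(sc, "hdfs://hdfsMasterIp:/data/naive_bayes_model")
  }
}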

 
