Chinese word segmentation uses the ANSJ toolkit, which requires two jars: ansj_seg and nlp-lang. The download locations for the ANSJ segmentation jars are:
ansj_seg jar:
https://oss.sonatype.org/content/repositories/releases/org/ansj/ansj_seg/
nlp-lang jar:
https://oss.sonatype.org/content/repositories/releases/org/nlpcn/nlp-lang/
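If the project is built with sbt, the two libraries can also be declared as managed dependencies instead of downloading the jars by hand. A minimal build.sbt sketch based on the artifacts above; the version numbers are placeholders, so substitute the versions you actually pick from the repository listings:

// build.sbt -- the version numbers below are examples only
libraryDependencies ++= Seq(
  "org.ansj"  % "ansj_seg" % "5.1.6",
  "org.nlpcn" % "nlp-lang" % "1.7.7"
)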
Chinese word segmentation:
import java.io.InputStream
import java.util
import org.ansj.domain.Result
import org.ansj.recognition.impl.StopRecognition
import org.ansj.splitWord.analysis.ToAnalysis
import org.ansj.util.MyStaticValue
import org.apache.spark.{SparkConf, SparkContext}
import org.nlpcn.commons.lang.tire.domain.{Forest, Value}
import org.nlpcn.commons.lang.tire.library.Library
import org.nlpcn.commons.lang.util.IOUtil
class ChineseSegment extends Serializable {
@transient private val sparkConf: SparkConf = new SparkConf().setAppName("chinese segment")
@transient private val sparkContext: SparkContext = SparkContext.getOrCreate(sparkConf)
private val stopLibRecog = new StopLibraryRecognition
private val stopLib: util.ArrayList[String] = stopLibRecog.stopLibraryFromHDFS(sparkContext)
private val selfStopRecognition: StopRecognition = stopLibRecog.stopRecognitionFilter(stopLib)
private val dicUserLibrary = new DicUserLibrary
@transient private val aListDicLibrary: util.ArrayList[Value] = dicUserLibrary.getUserLibraryList(sparkContext)
@transient private val dirLibraryForest: Forest = Library.makeForest(aListDicLibrary)
/** Chinese word segmentation with stop-word filtering and user-dictionary recognition */
def cNSeg(comment : String) : String = {
val result: Result = ToAnalysis.parse(comment,dirLibraryForest).recognition(selfStopRecognition)
result.toStringWithOutNature(" ")
}
}
/** Stop-word dictionary recognition:
 * Format: word  stop type (the type may be empty), fields separated by a Tab character
 * For example:
 * #
 * v nature
 * .*了 regex
 *
 * */
class StopLibraryRecognition extends Serializable {
def stopRecognitionFilter(arrayList: util.ArrayList[String]): StopRecognition ={
MyStaticValue.isQuantifierRecognition = true // merge numbers with their measure words
val stopRecognition = new StopRecognition
// filter tokens whose nature is preposition (p), interjection (e), conjunction (c), pronoun (r), particle (u), string (x) or onomatopoeia (o)
stopRecognition.insertStopNatures("p", "e", "c", "r", "u", "x", "o")
stopRecognition.insertStopNatures("w") // drop punctuation
// drop tokens of up to three characters that start with a Chinese numeral; longer tokens are kept
stopRecognition.insertStopRegexes("^一.{0,2}","^二.{0,2}","^三.{0,2}","^四.{0,2}","^五.{0,2}",
"^六.{0,2}","^七.{0,2}","^八.{0,2}","^九.{0,2}","^十.{0,2}")
stopRecognition.insertStopNatures("null") // drop tokens with a null nature
stopRecognition.insertStopRegexes(".{0,1}") // drop tokens of at most one character
stopRecognition.insertStopRegexes("^[a-zA-Z]{1,}") // drop tokens that consist only of English letters
stopRecognition.insertStopWords(arrayList) // add the stop words loaded from the dictionary
stopRecognition.insertStopRegexes("^[0-9]+") // drop tokens that are purely digits
stopRecognition.insertStopRegexes("[^a-zA-Z0-9\\u4e00-\\u9fa5]+") // drop tokens containing no Chinese characters, letters or digits
stopRecognition
}
def stopLibraryFromHDFS(sparkContext: SparkContext): util.ArrayList[String] ={
/** Method 2: read the entries of stop.dic from HDFS.
 * When running on a cluster, the stop-word data must be on HDFS so that every node can access the stop dictionary */
val stopLib: Array[String] = sparkContext.textFile("hdfs://zysdmaster000:8020/data/library/stop.dic").collect()
val arrayList: util.ArrayList[String] = new util.ArrayList[String]()
for (i <- 0 until stopLib.length) arrayList.add(stopLib(i))
arrayList
}
def getStopLibraryFromLocal: StopRecognition ={
/** Method 1: read stop.dic from the classpath; this approach is not suitable for running on a cluster.
 * When running locally, put stop.dic under the src (resources) directory and load it from the classpath;
 * both "/library/stop.dic" and "/stop.dic" will resolve the stop.dic data */
val stopRecognition: StopRecognition = new StopRecognition
stopRecognition.insertStopNatures("w")
val stream: InputStream = StopLibraryRecognition.this.getClass.getResourceAsStream("/stop.dic")
val strings: util.List[String] = IOUtil.readFile2List(stream,IOUtil.UTF8)
stopRecognition.insertStopWords(strings)
stopRecognition // return the populated StopRecognition explicitly
}
}
/** User-defined dictionary:
 * Format: word  nature  frequency
 * The word, its part-of-speech nature and its frequency are separated by Tab characters
 *
 * */
class DicUserLibrary extends Serializable {
def getUserLibraryList(sparkContext: SparkContext): util.ArrayList[Value] = {
/** Method 2: read the entries of userLibrary.dic from HDFS.
 * When running on a cluster, the user-library data must be on HDFS so that every node can access it */
val va: Array[String] = sparkContext.textFile("hdfs://zysdmaster000:8020/data/library/userLibrary.dic").collect()
val arrayList: util.ArrayList[Value] = new util.ArrayList[Value]()
for (i <- 0 until va.length) arrayList.add(new Value(va(i)))
arrayList
}
def getUserLibraryForest: Forest = {
/** Method 1: read userLibrary.dic from the classpath; this approach is not suitable for running on a cluster.
 * When running locally, put userLibrary.dic under the src (resources) directory and load it from the classpath;
 * both "/library/userLibrary.dic" and "/userLibrary.dic" will resolve the userLibrary.dic data */
val stream: InputStream = DicUserLibrary.this.getClass.getResourceAsStream("/library/userLibrary.dic")
println(stream) // printed only to verify that the "/library/userLibrary.dic" resource resolves
/** build the forest from userLibrary.dic placed under the src (resources) directory */
val forestLibrary: Forest = Library.makeForest(DicUserLibrary.this.getClass.getResourceAsStream("/userLibrary.dic"))
forestLibrary
}
}
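A short usage sketch for the segmenter defined above; the sample comment is made up for illustration, and the stop.dic and userLibrary.dic files must already exist at the HDFS paths hard-coded in the classes when this runs on a cluster:

object ChineseSegmentDemo {
  def main(args: Array[String]): Unit = {
    // ChineseSegment creates (or reuses) a SparkContext internally via SparkContext.getOrCreate
    val segment = new ChineseSegment
    // hypothetical sample comment
    val tokens: String = segment.cNSeg("这部电影画面很美,但是剧情有点拖沓")
    println(tokens) // space-separated tokens with stop words and filtered natures removed
  }
}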
TF-IDF computation and classification with Spark's NaiveBayes:
import java.sql.DriverManager
import java.util.Properties
import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, Tokenizer}
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
object DoubanSentimentAnalysis {
def main(args: Array[String]): Unit = {
run
}
def run: Unit ={
val doubanSentimentAnalysis = new DoubanSentimentAnalysis
val table_name_label: Array[String] = Array("short","long")
for (i <- 0 until table_name_label.length) {
doubanSentimentAnalysis.category2Mysql(table_name_label(i))
}
println(".................. douban sentiment analysis has finished ..................")
}
}
class DoubanSentimentAnalysis extends Serializable {
@transient private val sparkConf: SparkConf = new SparkConf().setAppName("douban sentiment analysis")
@transient private val sparkContext: SparkContext = new SparkContext(sparkConf)
private val sqlContext: SQLContext = new SQLContext(sparkContext)
import sqlContext.implicits._
case class RowDataRecord(cid: String, text: String)
def referenceModel(label: String): RDD[(Integer, Int)] = {
val url = "jdbc:mysql://mysql_IP/databaseName?useUnicode=true&characterEncoding=UTF-8"
// put the user name and password into the Properties object
val properties: Properties = new Properties
properties.setProperty("user", "myDatabaseName")
properties.setProperty("password", "myDatabasePassword")
val table_name = label + "_comment_douban_copy"
// connect to MySQL through JDBC using the url, the table name and the properties
val comment_douban: DataFrame = sqlContext.read.jdbc(url, table_name, properties)
val dataFrame: DataFrame = comment_douban.select("cid", "comment").repartition(200)
val segment = new ChineseSegment
val segmentDFrame = dataFrame.map { f =>
val str = segment.cNSeg(f.getString(1))
(f(0).toString, str)
}.toDF("cid", "segment")
// filter out rows whose segmentation result is empty, wrap the remaining rows in the case class RowDataRecord and convert to a DataFrame
val partsRDD = segmentDFrame.filter("segment != ''").select("cid", "segment").map { h =>
RowDataRecord(h(0).toString, h(1).toString)
}.toDF()
// split the segmented text into an array of words
val tokenizer: Tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val wordData: DataFrame = tokenizer.transform(partsRDD)
// compute term frequencies (TF) for the tokens
val hashingTF: HashingTF = new HashingTF().setInputCol("words").setOutputCol("tfFeatures")
val tfDFrame: DataFrame = hashingTF.transform(wordData)
// compute IDF from the TF values
val idf: IDF = new IDF().setInputCol("tfFeatures").setOutputCol("tfidfFeatures")
val idfModel: IDFModel = idf.fit(tfDFrame)
val tfidfDFrame: DataFrame = idfModel.transform(tfDFrame)
// TF-IDF vector for each comment
val cidDFIDF: DataFrame = tfidfDFrame.select($"cid", $"tfidfFeatures")
// load the pre-trained model
val naiveBayesModel: NaiveBayesModel = NaiveBayesModel.load(sparkContext, "hdfs://hdfsMasterIp:/data/naive_bayes_model")
// classify each comment: vow(0) is the comment id, vow(1) is its TF-IDF vector
val cidCategory = cidDFIDF.map { vow =>
(Integer.valueOf(vow(0).toString), naiveBayesModel.predict(vow.getAs[Vector](1).toSparse).toInt)
}
cidCategory
}
def category2Mysql(label: String): Unit = {
val tidCategory: RDD[(Integer, Int)] = referenceModel(label)
val mysqlURL = "jdbc:mysql://mysql_IP/database?useUnicode=true&rewriteBatchedStatements=true&characterEncoding=UTF-8"
val mysql = "INSERT INTO "+label+"_comment_douban_copy(cid,attitude) VALUES(?,?) ON DUPLICATE KEY UPDATE cid = values(cid),attitude = values(attitude)"
tidCategory.foreachPartition(it => {
val conn = DriverManager.getConnection(mysqlURL, "myUserName", "myDatabasePassword")
val pstat = conn.prepareStatement(mysql)
for (obj <- it) {
pstat.setInt(1, obj._1)
pstat.setInt(2, obj._2)
// add the statement to the batch
pstat.addBatch()
}
// execute the batch
pstat.executeBatch()
pstat.close()
conn.close()
})
}
}
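The inference code above only loads an already trained model from HDFS. For reference, here is a hedged sketch of how such a model could be produced with MLlib's NaiveBayes from labelled TF-IDF vectors; the training DataFrame, its Double "label" column (e.g. 0.0 = negative, 1.0 = positive), its "tfidfFeatures" column and the smoothing value are assumptions, not the author's actual training pipeline:

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, SQLContext}
class NaiveBayesTraining(sqlContext: SQLContext) extends Serializable {
  import sqlContext.implicits._
  /** trainingDFrame is assumed to carry a Double "label" column and a "tfidfFeatures" column
    * built with the same HashingTF/IDF steps used in referenceModel */
  def trainAndSave(trainingDFrame: DataFrame): Unit = {
    val labeledPoints = trainingDFrame.select($"label", $"tfidfFeatures").map { row =>
      LabeledPoint(row.getDouble(0), row.getAs[Vector](1))
    }
    // 1.0 is the Laplace smoothing parameter; tune as needed
    val model = NaiveBayes.train(labeledPoints, 1.0)
    // save under the same path that referenceModel loads from
    model.save(labeledPoints.sparkContext, "hdfs://hdfsMasterIp:/data/naive_bayes_model")
  }
}
Training like this is typically run once, offline, before the classification job above is submitted.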