Features
Tools
Dataset
Results
Introduction
- TextRank is an algorithm based on word co-occurrence (a toy version of its scoring iteration is sketched below); the ml module of the latest Spark release, 2.2.1, does not include a TextRank implementation.
- The Python textrank implementation:
- its input is segmented text and its output is the top-N hot words;
- after a small change to its source (textrankWeightWords), it returns both the top-N hot words and their corresponding hotness scores.
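As background, TextRank scores words with a PageRank-style iteration over the word co-occurrence graph. The toy function below sketches the standard update rule with damping factor d = 0.85; it illustrates the idea only, it is not the library's code, and it assumes every word has at least one co-occurring neighbour:

# Toy TextRank: PageRank-style score iteration on an unweighted co-occurrence graph.
# Illustration only; assumes every word in `graph` has at least one neighbour.
def textrank_scores(graph, d=0.85, max_iter=200):
    # graph: dict mapping each word to the set of words it co-occurs with
    score = dict((w, 1.0) for w in graph)
    for _ in range(max_iter):
        score = dict((w, (1 - d) + d * sum(score[j] / len(graph[j]) for j in graph[w]))
                     for w in graph)
    return score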
Scripts
Main script (Run.py)
"""
@author:
@contact:
@file:
@time:
"""
from __future__ import print_function
import sys, os
reload(sys)                              # Python 2: re-expose setdefaultencoding
sys.setdefaultencoding("utf-8")
from pyspark.sql import SparkSession
from pyspark.sql import Row
import textrankWeightWords, PreTreatment

# Read the Oracle connection settings and SPARK_HOME from the config file.
url, driver, oracle_user, password, data_table, time_column, DATATYPE_column, SPARK_HOME = PreTreatment.configfileParameter(1)
os.environ['SPARK_HOME'] = SPARK_HOME
spark = SparkSession.builder.appName("textRank").getOrCreate()

# The year to analyse is the first command-line argument (e.g. 2018).
sys_year = sys.argv[1]
jdbcDF = spark.read.format("jdbc").options(
    url=url,
    driver=driver,
    dbtable="(select * from " + data_table + " where to_char(" + time_column + ",'yyyy')=" + str(sys_year) + ")",
    user=oracle_user,
    password=password).load()

# Keep only rows whose SS column contains Chinese characters.
onlychina = jdbcDF.filter(jdbcDF['SS'].rlike("[\u4e00-\u9fa5]"))
onlychina.createOrReplaceTempView("onlychina")
# Collapse all SS values into one space-separated document.
df = spark.sql("select concat_ws(' ', collect_set(SS)) as text_group from onlychina")

def extractKeywords(text):
    # Run TextRank once and reuse both halves of its result
    # (avoids running the ranker twice per row).
    weights, words = textrankWeightWords.textrankWeightWords(text, 30)
    return Row(keyWord=PreTreatment.stopword(words),
               hotRate=PreTreatment.intArr2StrArr(weights))

rdd = df.rdd.map(lambda x: extractKeywords(x[0]))
df = rdd.toDF()
df.show(truncate=False)
spark.stop()
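To exercise the aggregation and mapping steps without an Oracle connection, the hypothetical smoke test below swaps the JDBC source for a tiny in-memory DataFrame (run it before spark.stop(); the sample rows are made up, only the column name SS matches the real table):

# Hypothetical smoke test: feed the same SQL and map logic from a local DataFrame.
test_df = spark.createDataFrame([Row(SS=u"大数据与机器学习"), Row(SS=u"spark支持机器学习")])
test_df.createOrReplaceTempView("onlychina")
grouped = spark.sql("select concat_ws(' ', collect_set(SS)) as text_group from onlychina")
grouped.rdd.map(lambda x: extractKeywords(x[0])).toDF().show(truncate=False)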
textrankWeightWords
"""
@author:
@contact:
@time:
"""
def textrankWeightWords(selfdoc, limit=5, merge=False):
    """Run snownlp's KeywordTextRank on a raw document string and return
    ([top-N hotness scores], [top-N keywords])."""
    from snownlp import seg
    from snownlp import normal
    from snownlp.summary import textrank
    from snownlp.summary import words_merge
    # Split into sentences, segment each one and drop snownlp's stop words.
    doc = []
    for sent in normal.get_sentences(selfdoc):
        words = seg.seg(sent)
        words = normal.filter_stop(words)
        doc.append(words)
    rank = textrank.KeywordTextRank(doc)
    rank.solve()
    ret = []
    for w in rank.top_index(limit):
        ret.append(w)
    if merge:
        # selfdoc is already a plain string, so it can be passed directly.
        wm = words_merge.SimpleMerge(selfdoc, ret)
        return wm.merge()
    # rank.top holds (word, score) pairs sorted by score, highest first.
    weight = []
    for i in rank.top[0:limit]:
        weight.append(i[1])
    return weight, ret
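A quick standalone sanity check of the function outside Spark (snownlp must be installed; the sample sentence is made up and the printed scores will vary):

# -*- coding: utf-8 -*-
from __future__ import print_function
import textrankWeightWords

text = u"大数据时代,机器学习和数据挖掘越来越重要,spark的ml模块支持常见的机器学习算法。"
weights, words = textrankWeightWords.textrankWeightWords(text, limit=5)
for word, weight in zip(words, weights):
    print(word, weight)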
PreTreatment
"""
@author:
@contact:
@time:
"""
from __future__ import print_function
from pyspark.sql import SparkSession
import os, sys, ConfigParser             # ConfigParser is the Python 2 module name
reload(sys)                              # Python 2: re-expose setdefaultencoding
sys.setdefaultencoding("utf-8")
os.environ['SPARK_HOME'] = "/usr/local/spark"
spark = SparkSession.builder.appName("PreTreatment").getOrCreate()
sc = spark.sparkContext
# Load the stop-word list from HDFS once at import time and keep it in driver memory.
stopwords = sc.textFile("hdfs://stopwords.txt").collect()
def configfileParameter(b):
    """Read the Oracle connection settings and SPARK_HOME from configfile.conf,
    which lives two directories above this script. The argument b is unused
    and kept only for call-site compatibility."""
    pwd = sys.path[0]
    path = os.path.abspath(os.path.join(pwd, os.pardir, os.pardir))
    os.chdir(path)
    cf = ConfigParser.ConfigParser()
    # Read the config file from the project root, not the filesystem root.
    cf.read(os.path.join(path, "configfile.conf"))
    url = cf.get("oracle", "url")
    driver = cf.get("oracle", "driver")
    oracle_user = cf.get("oracle", "oracle_user")
    password = cf.get("oracle", "password")
    data_table = cf.get("oracle", "data_table")
    time_column = cf.get("oracle", "time_column")
    DATATYPE_column = cf.get("oracle", "DATATYPE_column")
    SPARK_HOME = cf.get("SPARK_HOME", "SPARK_HOME")
    return url, driver, oracle_user, password, data_table, time_column, DATATYPE_column, SPARK_HOME
def stopword(strArr):
    # Keep only tokens longer than one character that are neither
    # pure digits nor in the stop-word list.
    stop_strArr = []
    for i in strArr:
        if len(i) > 1 and not i.isdigit() and i not in stopwords:
            stop_strArr.append(i)
    return stop_strArr
def intArr2StrArr(intArr):
    # Round each score to four decimals and convert it to a string.
    StrArr = []
    for i in intArr:
        StrArr.append(str(round(i, 4)))
    return StrArr
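For reference, a configfile.conf matching the cf.get calls above would look roughly like this; every value below is a placeholder, not a real deployment setting:

[oracle]
url = jdbc:oracle:thin:@//dbhost:1521/orcl
driver = oracle.jdbc.driver.OracleDriver
oracle_user = some_user
password = some_password
data_table = SOME_TABLE
time_column = SOME_DATE_COLUMN
DATATYPE_column = DATATYPE

[SPARK_HOME]
SPARK_HOME = /usr/local/spark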
Run
/usr/local/spark/bin/spark-submit --jars ojdbc6.jar,mysql-connector-java-5.1.41-bin.jar --py-files textrankWeightWords.py,PreTreatment.py --driver-memory 2G --master spark:// --executor-memory 10G --total-executor-cores 56 Run.py 2018 >> 1.log 2>&1
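The JDBC driver jars are shipped with --jars, while the two helper modules go through --py-files so the executors can import them. 2018 is the year filter consumed by Run.py as sys.argv[1], and all output is appended to 1.log. The master URL is left incomplete above and needs the cluster's actual spark://host:port address.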