The data set consists of 200,000 (20W) manually labeled text records, one per line in the form label#k-v#title. Sample records:
1#k-v#*亮亮爱宠*波波宠物指甲钳指甲剪附送锉刀适用小型犬及猫特价
1#k-v#*顺丰包邮*宠物药品圣马利诺PowerIgG免疫力球蛋白犬猫细小病毒
1#k-v#*包邮*法国罗斯蔓草本精华宠物浴液薰衣草护色润泽香波拍套餐
1#k-v#*包邮*家朵102宠物沐浴液
1#k-v#*包邮*家朵102宠物沐浴液猫
The ansj package is used to tokenize the text and remove stopwords. The code is as follows:
import java.io.File;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.ansj.domain.Result;
import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;

public class Seg {
    // Stopword list, loaded once from a file with one stopword per line.
    private static Set<String> stopwords = new HashSet<String>();

    static {
        File f = new File("");
        try {
            List<String> lines = FileUtils.readLines(f);
            for (String str : lines) {
                stopwords.add(str);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) throws IOException {
        File f = new File("");           // input file: label#k-v#title
        File resultFile = new File("");  // output file: label#k-v#segmented words
        List<String> lists = FileUtils.readLines(f);
        int count = 0;
        for (String str : lists) {
            count++;
            String index = str.split("#k-v#")[0];
            // System.out.println(count + " " + Integer.parseInt(index));
            Result res = ToAnalysis.parse(str.split("#k-v#")[1]);
            List<Term> terms = res.getTerms();
            String wordStr = "";
            for (Term t : terms) {
                String word = t.getName();
                // keep tokens longer than one character that are not stopwords
                if (word.length() > 1 && !stopwords.contains(word)) {
                    wordStr = wordStr + " " + word;
                }
            }
            if (!StringUtils.isEmpty(wordStr)) {
                FileUtils.write(resultFile, index + "#k-v#" + wordStr + "\n", true);
            }
            System.out.println(count);
        }
    }
}
For the TF-IDF step I used the implementation that ships with Spark MLlib. The code is as follows:
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

//case class FileRecord(index:Int,seg: String)

object TfIdf {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("TfIdfExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Schema for the segmented input: the label and the space-separated words.
    val schemaString = "index seg"
    val fields = schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, nullable = true))
    val schema = StructType(fields)

    // Each line of seg_src.txt is "label#k-v#word1 word2 ...".
    val srcRDD = sc.textFile("/tmp/seg_src.txt", 1).map(x => x.split("#k-v#")).map(attributes => Row(attributes(0), attributes(1).trim))
    val sentenceData = sqlContext.createDataFrame(srcRDD, schema).toDF("label", "seg")

    // Split the word string on whitespace, hash the tokens into 26 buckets,
    // then rescale the raw term frequencies with IDF.
    val tokenizer = new Tokenizer().setInputCol("seg").setOutputCol("words")
    val wordsData = tokenizer.transform(sentenceData)
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(26)
    val featurizedData = hashingTF.transform(wordsData)
    val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
    val idfModel = idf.fit(featurizedData)
    val rescaledData = idfModel.transform(featurizedData)

    rescaledData.select("features", "label").take(3).foreach(println)
    // Persist the (features, label) pairs as JSON for the next step.
    rescaledData.select("features", "label").write.format("json").save("/tmp/tfidf.model")
  }
}
The result is saved in JSON format; sample records are shown below:
{"features":{"type":0,"size":26,"indices":[0,5,6,7,9,10,14,17,21],"values":[2.028990788466258,1.8600672974067514,1.8464729103095205,2.037399707294254,1.908861495143531,3.6260607728633083,2.0363086347259687,1.8261747092361593,2.0640809711702492]},"label":"1"}
{"features":{"type":0,"size":26,"indices":[7,8,17],"values":[4.074799414588508,2.1216332358971366,1.8261747092361593]},"label":"1"}
Because the random forest algorithm in Spark MLlib expects data in libsvm format, the JSON features need to be converted. The code is as follows:
// JSONObject / JSONArray below are assumed to come from the org.json library.
File f = new File("D:/sogouOutput/json_feature");
File libsvmFile = new File("D:/sogouOutput/libsvm_feature");
List<String> features = FileUtils.readLines(f);
for (String str : features) {
    JSONObject obj = new JSONObject(str);
    String label = obj.getString("label");
    JSONArray indexArr = obj.getJSONObject("features").getJSONArray("indices");
    JSONArray valueArr = obj.getJSONObject("features").getJSONArray("values");
    int length = indexArr.length();
    String line = label + " ";
    // A TreeMap keeps the feature indices in ascending order, as libsvm requires.
    Map<Integer, Double> indiceAndValue = new TreeMap<Integer, Double>();
    for (int i = 0; i < length; i++) {
        indiceAndValue.put(indexArr.getInt(i), valueArr.getDouble(i));
        // line = line + indexArr.getInt(i) + ":" + valueArr.getDouble(i) + " ";
    }
    // The feature index must not be 0 (I did not know why at this point;
    // the reason is explained at the end of this post).
    if (indiceAndValue.containsKey(0)) {
        indiceAndValue.remove(0);
    }
    for (Map.Entry<Integer, Double> m : indiceAndValue.entrySet()) {
        line = line + m.getKey() + ":" + m.getValue() + " ";
    }
    // System.out.println(StringUtils.substring(line, 0, -1));
    FileUtils.write(libsvmFile, StringUtils.substring(line, 0, -1) + "\n", true);
}
Sample records of the converted result:
1 7:2.037399707294254
1 1:1.6033119355738932 5:1.8600672974067514 7:4.074799414588508 10:1.8130303864316542 13:2.0344821501999344 15:2.2043195316439834 18:2.0104112775954426 20:2.0108489143639154 25:1.9189925465072746
1 3:5.510668692397079 5:1.8600672974067514 6:1.8464729103095205 7:4.074799414588508 17:1.8261747092361593
1 3:5.510668692397079 5:1.8600672974067514 6:1.8464729103095205 7:2.037399707294254 13:2.0344821501999344 17:1.8261747092361593 20:2.0108489143639154
1 7:2.037399707294254
The classification code is as follows:
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
// $example off$

object RandomForestClassifierExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("RandomForestClassifierExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    // $example on$
    // Load and parse the data file, converting it to a DataFrame.
    val data = sqlContext.read.format("libsvm").load("/tmp/libsvm_feature")
    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    // The feature indices in the libsvm file must be one-based and in ascending order.
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 26 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(26).fit(data)
    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
    // Train a RandomForest model.
    val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10)
    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)
    // Make predictions.
    val predictions = model.transform(testData)
    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)
    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("precision")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))
    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    println("Learned classification forest model:\n" + rfModel.toDebugString)
    // $example off$
    sc.stop()
  }
}
While running, the line val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
fails with the following error:
Caused by: java.lang.IllegalArgumentException: requirement failed: indices should be one-based and in ascending order
After some investigation, the cause is that a feature index of 0 is not allowed: libsvm indices are expected to be one-based, and Spark's parser subtracts 1 from every index, as its source code shows:
private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
  val items = line.split(' ')
  val label = items.head.toDouble
  val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
    val indexAndValue = item.split(':')
    val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
    val value = indexAndValue(1).toDouble
    (index, value)
  }.unzip
  // check if indices are one-based and in ascending order
  var previous = -1
  var i = 0
  val indicesLength = indices.length
  while (i < indicesLength) {
    val current = indices(i)
    require(current > previous, s"indices should be one-based and in ascending order;"
      + s""" found current=$current, previous=$previous; line="$line"""")
    previous = current
    i += 1
  }
  (label, indices.toArray, values.toArray)
}
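Given that the parser subtracts 1 from every index, a less lossy alternative to removing feature 0 in the conversion code above is to add 1 to each index when writing the libsvm line, so that hash bucket 0 survives as index 1 instead of being dropped. Below is a minimal sketch of that idea; the class and method names (LibsvmLine, toLibsvmLine) are made up for illustration and are not part of the conversion code above.
import java.util.Map;
import java.util.TreeMap;

public class LibsvmLine {
    // Write one libsvm line with 0-based hash buckets shifted to 1-based indices.
    static String toLibsvmLine(String label, Map<Integer, Double> bucketToWeight) {
        StringBuilder sb = new StringBuilder(label);
        for (Map.Entry<Integer, Double> e : bucketToWeight.entrySet()) {
            sb.append(" ").append(e.getKey() + 1).append(":").append(e.getValue());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Values taken from the first JSON record shown earlier in this post.
        Map<Integer, Double> v = new TreeMap<Integer, Double>();
        v.put(0, 2.028990788466258);   // bucket 0 is kept instead of being removed
        v.put(5, 1.8600672974067514);
        System.out.println(toLibsvmLine("1", v));
        // prints: 1 1:2.028990788466258 6:1.8600672974067514
    }
}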