libsvm Chinese Text Classification, Java Version

This week I had planned to build a Chinese text classification model with word2vec + LSTM, but my boss had used libsvm before and asked me to use libsvm instead. After two days of tinkering I basically got it working.

Along the way I ran into all sorts of problems, which I am writing down here.


First, download the libsvm package from this link:

http://www.csie.ntu.edu.tw/~cjlin/cgi-bin/libsvm.cgi?+http://www.csie.ntu.edu.tw/~cjlin/libsvm+zip

After downloading, unzip it, cd into the directory, and just run make.



Then normalize the text data into the following format:

2 2017:1.23527900896424 2080:1.3228803416955244 21233:3.475992040593523 
2 576:1.0467435856485432 967:1.0968877798239958 3940:1.7482714392181495 4449:1.7535719911308003 
2 967:1.0968877798239958 1336:1.3551722790297116 5611:1.8303003497257173 14735:1.7682821161365336 
1 7:0.02425295226485008 32:0.009012036411194203 80:0.0057407001135544745 127:0.020374370371014396 

This is the standard libsvm format: each line is "label index1:value1 index2:value2 ...". Word segmentation is done with the ansj tool, and the feature values are tf-idf weights. The feature indices must be sorted in ascending order, otherwise the libsvm training tool throws the following error:

Libsvm : Wrong input format at line 1
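
For example, a line like

1 80:0.0057 7:0.0242 127:0.0203

triggers this error because index 7 appears after index 80; the features must first be rearranged by ascending index:

1 7:0.0242 80:0.0057 127:0.0203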


For concrete usage you can refer to this blog post: http://endual.iteye.com/blog/1267442. The key part is knowing how to generate the libsvm format file.



Below is the code for the tool that converts the text into libsvm format. It uses a lot of Java 8 features; I'm used to writing Scala, so going back to Java feels rather verbose. Bear with me:

package com.meituan.model.libsvm;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang3.StringUtils;

import com.meituan.nlp.util.WordUtil;
import com.meituan.nlp.util.TextUtil;
import com.meituan.model.util.Config;

public class DocumentTransForm {
	private static String inputpath = Config.getString("data.path");
	private static String outputpath = Config.getString("data.libsvm");
	// term -> (feature index, document frequency)
	private static Map<String, Terms> mapTerms = new HashMap<>();
	public static int documentTotal = 0;

	// First pass: build the vocabulary (term -> feature index and document
	// frequency) and count the total number of documents for idf.
	public static void getTerms(String file) {
		BufferedReader br = null;
		try {
			br = new BufferedReader(new InputStreamReader(new FileInputStream(
					file)));

			String lines = br.readLine();
			int featurecount = 1;
			while (lines != null) {
				String line = lines.split("\t")[0];
				// segment with ansj, then drop stopwords, single characters
				// and tokens that start with a number
				Set<String> sets = ToAnalysis
						.parse(WordUtil.replaceAllSynonyms(TextUtil
								.fan2Jian(WordUtil.replaceAll(line
										.toLowerCase()))))
						.getTerms()
						.stream()
						.map(x -> x.getName())
						.filter(x -> !WordUtil.isStopword(x) && x.length() > 1
								&& !WordUtil.startWithNumeber(x))
						.collect(Collectors.toSet());
				if (sets != null) {

					for (String key : sets) {
						if (!mapTerms.containsKey(key)) {
							Terms terms = new Terms(key, featurecount);
							mapTerms.put(key, terms);
							featurecount++;
						} else {
							mapTerms.get(key).incrFreq();
						}
					}
					documentTotal++;
				}

				lines = br.readLine();

			}

		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			if (br != null) {
				try {
					br.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	// Second pass: write each document as "label index:tfidf ..." with
	// feature indices sorted in ascending order, as libsvm requires.
	public static void getLibsvmFile(String input, String output) {
		BufferedReader br = null;
		BufferedWriter bw = null;

		try {
			br = new BufferedReader(new InputStreamReader(new FileInputStream(
					input)));

			bw = new BufferedWriter(new OutputStreamWriter(
					new FileOutputStream(output)));

			String lines = br.readLine();

			// note: processing stops at the first blank line or at EOF
			while (StringUtils.isNotBlank(lines)) {
				// map the original labels: "-1" -> "2", everything else -> "1"
				String label = lines.split("\t")[1].equalsIgnoreCase("-1") ? "2"
						: "1";
				String content = lines.split("\t")[0];
				// term frequency of every remaining token in this document
				Map<String, Long> maps = ToAnalysis
						.parse(WordUtil.replaceAllSynonyms(TextUtil
								.fan2Jian(WordUtil.replaceAll(content
										.toLowerCase()))))
						.getTerms()
						.stream()
						.map(x -> x.getName())
						.filter(x -> !WordUtil.isStopword(x) && x.length() > 1
								&& !WordUtil.startWithNumeber(x))
						.collect(
								Collectors.groupingBy(p -> p,
										Collectors.counting()));

				if (maps != null && !maps.isEmpty()) {
					StringBuilder sb = new StringBuilder();
					sb.append(label).append(" ");
					// total token count of this document, used for tf
					int sum = maps.values().stream()
							.mapToInt(Long::intValue).sum();

					// a TreeMap keeps the feature indices in ascending
					// order, which the libsvm format requires
					Map<Integer, Double> treeMap = new TreeMap<>();
					for (Entry<String, Long> map : maps.entrySet()) {

						String key = map.getKey();
						double tf = TFIDF.tf(map.getValue(), sum);
						// this key must exist: it was added in getTerms
						double idf = TFIDF.idf(documentTotal, mapTerms.get(key)
								.getFreq());

						treeMap.put(mapTerms.get(key).getId(),
								TFIDF.tfidf(tf, idf));

					}
					treeMap.forEach((x, y) -> sb.append(x).append(":")
							.append(y).append(" "));
					bw.write(sb.toString());
					bw.newLine();
				}

				lines = br.readLine();

			}

		} catch (Exception e) {
			e.printStackTrace();

		} finally {
			try {
				if (bw != null) {
					bw.close();
				}
				if (br != null) {
					br.close();
				}
			} catch (Exception e) {
				e.printStackTrace();
			}
		}

	}

	public static void main(String[] args) {

		getTerms(inputpath);
		System.out.println("documentTotal is :" + documentTotal);
		getLibsvmFile(inputpath, outputpath);

		// leftover sanity check of groupingBy + counting; prints 2
		List<String> list = new ArrayList<>(Arrays.asList("a", "a"));
		Map<String, Long> map = list.stream().collect(
				Collectors.groupingBy(p -> p, Collectors.counting()));
		System.out.println(map.values().stream()
				.mapToInt(Long::intValue).sum());

	}

}
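
The Terms and TFIDF helper classes are not shown above. Here is a minimal sketch of what they might look like, assuming a plain tf = count/total and a +1-smoothed idf; the field names and the exact weighting variant are my assumptions, not the original code:

public class Terms {
	private final String name; // the term itself
	private final int id;      // feature index, assigned in insertion order
	private int freq = 1;      // document frequency, 1 when first seen

	public Terms(String name, int id) {
		this.name = name;
		this.id = id;
	}

	public int getId() { return id; }
	public int getFreq() { return freq; }
	public void incrFreq() { freq++; }
}

class TFIDF {
	// term frequency: raw count normalized by document length
	public static double tf(long count, int total) {
		return (double) count / total;
	}

	// inverse document frequency with +1 smoothing (assumed variant)
	public static double idf(int documentTotal, int docFreq) {
		return Math.log((double) documentTotal / (docFreq + 1.0));
	}

	public static double tfidf(double tf, double idf) {
		return tf * idf;
	}
}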



Then it's time to train. First, scale the data to [0, 1]:

./svm-scale  -l 0 -u 1  /Users/shuubiasahi/Documents/workspace/spark-model/file/libsvem.txt >/Users/shuubiasahi/Documents/workspace/spark-model/file/libsvem_scale.txt
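
Note that a test set must be scaled with the same parameters as the training set. svm-scale supports this via -s (save the scaling parameters) and -r (restore them); for example (the file names here are placeholders):

./svm-scale -l 0 -u 1 -s scale_params.txt train.txt > train_scale.txt
./svm-scale -r scale_params.txt test.txt > test_scale.txt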


Now train. libsvm offers quite a few parameters; run ./svm-train without arguments to list them. Here -t 0 selects a linear kernel and -h 0 disables the shrinking heuristics:

./svm-train -h 0 -t 0 /Users/shuubiasahi/Documents/workspace/spark-model/file/libsvem_scale.txt /Users/shuubiasahi/Documents/workspace/spark-model/file/model.txt
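
After training finishes, predictions on a (scaled) test file are made with svm-predict; for example (the test file path is a placeholder):

./svm-predict /Users/shuubiasahi/Documents/workspace/spark-model/file/test_scale.txt /Users/shuubiasahi/Documents/workspace/spark-model/file/model.txt /Users/shuubiasahi/Documents/workspace/spark-model/file/predict.txt

svm-predict prints the accuracy on the test file and writes the predicted labels to the last path.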


Personally I find SVM theory fairly approachable; Li Hang's book Statistical Learning Methods explains it clearly.






