/**
 * Trains a LingPipe TF/IDF classifier from a directory of labelled corpus files
 * and writes the compiled model to disk.
 *
 * <p>Every entry of {@code CATEGORIES} is looked up as a sub-directory of
 * {@code TDIR}; each regular file found there is handed to the trainer as one
 * training document for that category.
 *
 * @author laigood
 */
public class TrainTClassifier {
    // Root folder of the training corpus (one sub-directory per category)
    private static final File TDIR = new File("f:\\data\\category");
    // Category labels; each one is also the name of a sub-directory of TDIR
    private static final String[] CATEGORIES = { "金融", "军事", "医学", "饮食" };

    /**
     * Trains on all category folders and serializes the resulting classifier.
     *
     * @param args unused
     * @throws ClassNotFoundException declared for backward compatibility
     * @throws IOException if a corpus file cannot be read or the model cannot be written
     */
    public static void main(String[] args) throws ClassNotFoundException,
            IOException {
        TfIdfClassifierTrainer<CharSequence> classifier = new TfIdfClassifierTrainer<CharSequence>(
                new TokenFeatureExtractor(CharacterTokenizerFactory.INSTANCE));

        // Training pass: one Classified<CharSequence> per corpus file
        for (int i = 0; i < CATEGORIES.length; i++) {
            File classDir = new File(TDIR, CATEGORIES[i]);
            // listFiles() returns null for a non-directory or on an I/O error
            File[] files = classDir.isDirectory() ? classDir.listFiles() : null;
            if (files == null) {
                System.out.println("不能找到目录=" + classDir);
                // Skip this category; iterating a null listing would throw
                continue;
            }
            Classification classification = new Classification(CATEGORIES[i]);
            for (File file : files) {
                if (!file.isFile()) {
                    // Nested directories are not training documents
                    continue;
                }
                String text = Files.readFromFile(file, "utf-8");
                System.out.println("正在训练 " + CATEGORIES[i] + file.getName());
                Classified<CharSequence> classified = new Classified<CharSequence>(
                        text, classification);
                classifier.handle(classified);
            }
        }

        // Write the compiled classifier model to a file
        System.out.println("开始生成分类器");
        String modelFile = "f:\\data\\category\\tclassifier";
        ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(
                modelFile));
        try {
            classifier.compileTo(os);
        } finally {
            // Always release the file handle, even when compilation fails
            os.close();
        }
        System.out.println("分类器生成完成");
    }
}
TestTClassifier:测试分类的准确度,测试数据的存放方式与上面的训练数据类似。
/**
 * Measures the accuracy of the compiled TF/IDF classifier on a labelled test
 * corpus and prints the resulting confusion matrix.
 *
 * <p>The test corpus is laid out like the training corpus: one sub-directory of
 * {@code TDIR} per entry of {@code CATEGORIES}.
 *
 * @author laigood
 */
public class TestTClassifier {
    // Root folder of the test corpus (one sub-directory per category)
    private static final File TDIR = new File("f:\\data\\test");
    private static final String[] CATEGORIES = { "金融", "军事", "医学", "饮食" };
    // Column widths used when printing the confusion matrix
    private static final int LABEL_WIDTH = 10;
    private static final int CELL_WIDTH = 8;

    /**
     * Loads the model, classifies every test file and prints the results.
     *
     * @param args unused
     * @throws ClassNotFoundException if the serialized model refers to a class
     *         that is not on the classpath
     */
    public static void main(String[] args) throws ClassNotFoundException {
        // Location of the serialized classifier model
        String modelFile = "f:\\data\\category\\tclassifier";
        ScoredClassifier<CharSequence> compiledClassifier;
        try {
            compiledClassifier = loadClassifier(modelFile);
        } catch (IOException ie) {
            System.out.println("IO Error: Model file " + modelFile + " missing");
            // Without a model nothing can be classified; stop instead of
            // dereferencing a null classifier further down
            return;
        }

        // Classify every file of every category and record reference vs. response
        ConfusionMatrix confMatrix = new ConfusionMatrix(CATEGORIES);
        for (int i = 0; i < CATEGORIES.length; ++i) {
            File classDir = new File(TDIR, CATEGORIES[i]);
            // listFiles() returns null for a non-directory or on an I/O error
            File[] files = classDir.listFiles();
            if (files == null) {
                System.out.println("不能找到目录=" + classDir);
                continue;
            }
            for (File file : files) {
                String text;
                try {
                    text = Files.readFromFile(file, "utf-8");
                } catch (IOException ie) {
                    System.out.println("不能读取 " + file.getName());
                    // Skip the file: classifying an empty string would add a
                    // bogus entry to the confusion matrix
                    continue;
                }
                System.out.println("测试 " + CATEGORIES[i]
                        + File.separator + file.getName());
                ScoredClassification classification = compiledClassifier
                        .classify(text);
                String best = classification.bestCategory();
                confMatrix.increment(CATEGORIES[i], best);
                System.out.println("最适合的分类: " + best);
            }
        }

        printResults(confMatrix);
    }

    /** Deserializes the compiled classifier, closing the stream in all cases. */
    @SuppressWarnings("unchecked")
    private static ScoredClassifier<CharSequence> loadClassifier(String modelFile)
            throws IOException, ClassNotFoundException {
        ObjectInputStream oi = new ObjectInputStream(new FileInputStream(
                modelFile));
        try {
            return (ScoredClassifier<CharSequence>) oi.readObject();
        } finally {
            oi.close();
        }
    }

    /** Prints the confusion matrix followed by the accuracy summary. */
    private static void printResults(ConfusionMatrix confMatrix) {
        NumberFormat nf = NumberFormat.getInstance();
        nf.setMaximumIntegerDigits(1);
        nf.setMaximumFractionDigits(3);

        System.out.println("--------------------------------------------");
        System.out.println("- 结果 ");
        System.out.println("--------------------------------------------");

        // Header row, padded with the same widths as the data rows below
        StringBuilder sb = new StringBuilder();
        sb.append(pad("CATEGORY", LABEL_WIDTH, true));
        for (int i = 0; i < CATEGORIES.length; i++) {
            sb.append(pad(CATEGORIES[i], CELL_WIDTH, false));
        }
        System.out.println(sb.toString());

        int[][] imatrix = confMatrix.matrix();
        for (int i = 0; i < imatrix.length; i++) {
            sb = new StringBuilder();
            sb.append(pad(CATEGORIES[i], LABEL_WIDTH, true));
            for (int j = 0; j < imatrix[i].length; j++) {
                sb.append(pad("" + imatrix[i][j], CELL_WIDTH, false));
            }
            System.out.println(sb.toString());
        }

        System.out.println("准确度: "
                + nf.format(confMatrix.totalAccuracy()));
        System.out.println("总共正确数 : " + confMatrix.totalCorrect());
        System.out.println("总数:" + confMatrix.totalCount());
    }

    /**
     * Pads {@code s} with spaces up to {@code width} characters.
     *
     * @param right true to append the padding, false to prepend it
     */
    private static String pad(String s, int width, boolean right) {
        // The fill count is passed explicitly; the shorter fillin overloads
        // use a fixed count of 3, which does not reach the column width
        return StringTools.fillin(s, width, right, ' ', width - s.length());
    }
}
补上 StringTools 工具类:
/** * A class containing a bunch of string utilities - <br> * a. filterChars: Remove extraneous characters from a string and return a * "clean" string. <br> * b. getSuffix: Given a file name return its extension. <br> * c. fillin: pad or truncate a string to a fixed number of characters. <br> * d. removeAmpersandStrings: remove strings that start with ampersand <br> * e. shaDigest: Compute the 40 byte digest signature of a string <br> */ public class StringTools { public static final Locale LOCALE = new Locale("en"); // * -- String limit for StringTools private static int STRING_TOOLS_LIMIT = 1000000; // *-- pre-compiled RE patterns private static Pattern extPattern = Pattern.compile("^.*[.](.*?){1}quot;); private static Pattern spacesPattern = Pattern.compile("\\s+"); private static Pattern removeAmpersandPattern = Pattern.compile("&[^;]*?;"); /** * Removes non-printable spaces and replaces with a single space * * @param in * String with mixed characters * @return String with collapsed spaces and printable characters */ public static String filterChars(String in) { return (filterChars(in, "", ' ', true)); } public static String filterChars(String in, boolean newLine) { return (filterChars(in, "", ' ', newLine)); } public static String filterChars(String in, String badChars) { return (filterChars(in, badChars, ' ', true)); } public static String filterChars(String in, char replaceChar) { return (filterChars(in, "", replaceChar, true)); } public static String filterChars(String in, String badChars, char replaceChar, boolean newLine) { if (in == null) return ""; int inLen = in.length(); if (inLen > STRING_TOOLS_LIMIT) return in; try { // **-- replace non-recognizable characters with spaces StringBuffer out = new StringBuffer(); int badLen = badChars.length(); for (int i = 0; i < inLen; i++) { char ch = in.charAt(i); if ((badLen != 0) && removeChar(ch, badChars)) { ch = replaceChar; } else if (!Character.isDefined(ch) && !Character.isSpaceChar(ch)) { ch = replaceChar; 
} out.append(ch); } // *-- replace new lines with space Matcher matcher = null; in = out.toString(); // *-- replace consecutive spaces with single space and remove // leading/trailing spaces in = in.trim(); matcher = spacesPattern.matcher(in); in = matcher.replaceAll(" "); } catch (OutOfMemoryError e) { return in; } return in; } // *-- remove any chars found in the badChars string private static boolean removeChar(char ch, String badChars) { if (badChars.length() == 0) return false; for (int i = 0; i < badChars.length(); i++) { if (ch == badChars.charAt(i)) return true; } return false; } /** * Return the extension of a file, if possible. * * @param filename * @return string */ public static String getSuffix(String filename) { if (filename.length() > STRING_TOOLS_LIMIT) return (""); Matcher matcher = extPattern.matcher(filename); if (!matcher.matches()) return ""; return (matcher.group(1).toLowerCase(LOCALE)); } public static String fillin(String in, int len) { return fillin(in, len, true, ' ', 3); } public static String fillin(String in, int len, char fillinChar) { return fillin(in, len, true, fillinChar, 3); } public static String fillin(String in, int len, boolean right) { return fillin(in, len, right, ' ', 3); } public static String fillin(String in, int len, boolean right, char fillinChar) { return fillin(in, len, right, fillinChar, 3); } /** * Return a string concatenated or padded to the specified length * * @param in * string to be truncated or padded * @param len * int length for string * @param right * boolean fillin from the left or right * @param fillinChar * char to pad the string * @param numFills * int number of characters to pad * @return String of specified length */ public static String fillin(String in, int len, boolean right, char fillinChar, int numFills) { // *-- return if string is of required length int slen = in.length(); if ((slen == len) || (slen > STRING_TOOLS_LIMIT)) return (in); // *-- build the fillin string StringBuffer fillinStb = 
new StringBuffer(); for (int i = 0; i < numFills; i++) fillinStb.append(fillinChar); String fillinString = fillinStb.toString(); // *-- truncate and pad string if length exceeds required length if (slen > len) { if (right) return (in.substring(0, len - numFills) + fillinString); else return (fillinString + in.substring(slen - len + numFills, slen)); } // *-- pad string if length is less than required length DatabaseEntry // dbe = dbt.getNextKey(); String dbkey = new String (dbe.getData()); StringBuffer sb = new StringBuffer(); if (right) sb.append(in); sb.append(fillinString); if (!right) sb.append(in); return (sb.toString()); } /** * Remove ampersand strings such as \ * * @param in * Text string extracted from Web pages * @return String Text string without ampersand strings */ public static String removeAmpersandStrings(String in) { if (in.length() > STRING_TOOLS_LIMIT) return (in); Matcher matcher = removeAmpersandPattern.matcher(in); return (matcher.replaceAll("")); } /** * Escape back slashes * * @param in * Text to be escaped * @return String Escaped test */ public static String escapeText(String in) { StringBuffer sb = new StringBuffer(); for (int i = 0; i < in.length(); i++) { char ch = in.charAt(i); if (ch == '\\') sb.append("\\\\"); else sb.append(ch); } return (sb.toString()); } /** * Get the SHA signature of a string * * @param in * String * @return String SHA signature of in */ public static String shaDigest(String in) { StringBuffer out = new StringBuffer(); if ((in == null) || (in.length() == 0)) return (""); try { // *-- create a message digest instance and compute the hash // byte array MessageDigest md = MessageDigest.getInstance("SHA-1"); md.reset(); md.update(in.getBytes()); byte[] hash = md.digest(); // *--- Convert the hash byte array to hexadecimal format, pad // hex chars with leading zeroes // *--- to get a signature of consistent length (40) for all // strings. 
for (int i = 0; i < hash.length; i++) { out.append(fillin(Integer.toString(0xFF & hash[i], 16), 2, false, '0', 1)); } } catch (OutOfMemoryError e) { return ("<-------------OUT_OF_MEMORY------------>"); } catch (NoSuchAlgorithmException e) { return ("<------SHA digest algorithm not found--->"); } return (out.toString()); } /** * Return the string with the first letter upper cased * * @param in * @return String */ public static String firstLetterUC(String in) { if ((in == null) || (in.length() == 0)) return (""); String out = in.toLowerCase(LOCALE); String part1 = out.substring(0, 1); String part2 = out.substring(1, in.length()); return (part1.toUpperCase(LOCALE) + part2.toLowerCase(LOCALE)); } /** * Return a pattern that can be used to collapse consecutive patterns of the * same type * * @param entityTypes * A list of entity types * @return Regex pattern for the entity types */ public static Pattern getCollapsePattern(String[] entityTypes) { Pattern collapsePattern = null; StringBuffer collapseStr = new StringBuffer(); for (int i = 0; i < entityTypes.length; i++) { collapseStr.append("(<\\/"); collapseStr.append(entityTypes[i]); collapseStr.append(">\\s+"); collapseStr.append("<"); collapseStr.append(entityTypes[i]); collapseStr.append(">)|"); } collapsePattern = Pattern.compile(collapseStr.toString().substring(0, collapseStr.length() - 1)); return (collapsePattern); } /** * return a double that indicates the degree of similarity between two strings * Use the Jaccard similarity, i.e. 
the ratio of A intersection B to A union B * * @param first * string * @param second * string * @return double degreee of similarity */ public static double stringSimilarity(String first, String second) { if ((first == null) || (second == null)) return (0.0); String[] a = first.split("\\s+"); String[] b = second.split("\\s+"); // *-- compute a union b HashSet<String> aUnionb = new HashSet<String>(); HashSet<String> aTokens = new HashSet<String>(); HashSet<String> bTokens = new HashSet<String>(); for (int i = 0; i < a.length; i++) { aUnionb.add(a[i]); aTokens.add(a[i]); } for (int i = 0; i < b.length; i++) { aUnionb.add(b[i]); bTokens.add(b[i]); } int sizeAunionB = aUnionb.size(); // *-- compute a intersect b Iterator <String> iter = aUnionb.iterator(); int sizeAinterB = 0; while (iter != null && iter.hasNext()) { String token = (String) iter.next(); if (aTokens.contains(token) && bTokens.contains(token)) sizeAinterB++; } return ((sizeAunionB > 0) ? (sizeAinterB + 0.0) / sizeAunionB : 0.0); } /** * Return the edit distance between the two strings * * @param s1 * @param s2 * @return double */ public static double editDistance(String s1, String s2) { if ((s1.length() == 0) || (s2.length() == 0)) return (0.0); return EditDistance.editDistance(s1.subSequence(0, s1.length()), s2 .subSequence(0, s2.length()), false); } /** * Return a string with the contents from the passed reader * * @param r Reader * @return String */ public static String readerToString(Reader r) { int charValue; StringBuffer sb = new StringBuffer(1024); try { while ((charValue = r.read()) != -1) sb.append((char) charValue); } catch (IOException ie) { sb.setLength(0); } return (sb.toString()); } /** * Clean up a sentence by consecutive non-alphanumeric chars with a single * non-alphanumeric char * * @param in Array of chars * @return String */ public static String cleanString(char[] in) { int len = in.length; boolean prevOK = true; for (int i = 0; i < len; i++) { if (Character.isLetterOrDigit(in[i]) || 
Character.isWhitespace(in[i])) prevOK = true; else { if (!prevOK) in[i] = ' '; prevOK = false; } } return (new String(in)); } /** * Return a clean file name * * @param filename * @return String */ public static String parseFile(String filename) { return (filterChars(filename, "\\/_:.")); } }