TF-IDF (term frequency–inverse document frequency) is a weighting technique commonly used in information retrieval and data mining. TF stands for Term Frequency, and IDF stands for Inverse Document Frequency.
Idea: segment each text into words, use the TF-IDF algorithm to turn the text into a term vector, and then measure the similarity between the vectors with the cosine measure.
Required jars: je-analysis-1.5.3.jar and lucene-core-2.4.1.jar (Lucene versions above 4 conflict with je-analysis).
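Before the full classes below, here is a minimal standalone sketch (not part of the original code; the class name and toy data are made up for illustration) of the two steps: weight terms with TF-IDF, then compare the resulting vectors with cosine similarity.

import java.util.Arrays;

// Toy illustration only: two tiny "documents" already split into words.
public class TfIdfToyExample {
    public static void main(String[] args) {
        String[][] docs = {
            {"我", "喜欢", "读书"},
            {"我", "喜欢", "旅游"}
        };
        String[] vocab = {"我", "喜欢", "读书", "旅游"};
        double[][] vectors = new double[docs.length][vocab.length];

        for (int d = 0; d < docs.length; d++) {
            for (int v = 0; v < vocab.length; v++) {
                // TF: occurrences of the word divided by the document length.
                double tf = count(docs[d], vocab[v]) / (double) docs[d].length;
                // IDF as in the article: log10((1 + |D|) / |Dt|).
                double dt = 0;
                for (String[] doc : docs) {
                    if (count(doc, vocab[v]) > 0) dt++;
                }
                double idf = Math.log10((1 + docs.length) / Math.max(dt, 1));
                vectors[d][v] = tf * idf;
            }
        }
        System.out.println(Arrays.toString(vectors[0]));
        System.out.println(Arrays.toString(vectors[1]));
        System.out.println("cosine similarity: " + cosine(vectors[0], vectors[1]));
    }

    private static int count(String[] doc, String word) {
        int c = 0;
        for (String w : doc) {
            if (w.equals(word)) c++;
        }
        return c;
    }

    private static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / Math.sqrt(na * nb);
    }
}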
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jeasy.analysis.MMAnalyzer;

/**
* Match two texts directly against each other.
*
* @author rock
*
*/
public class GetText {
private static List<String> fileList = new ArrayList<String>();
private static HashMap<String, HashMap<String, Double>> allTheTf = new HashMap<String, HashMap<String, Double>>();
private static HashMap<String, HashMap<String, Integer>> allTheNormalTF = new HashMap<String, HashMap<String, Integer>>();
private static LinkedHashMap<String, Double[]> vectorMap = new LinkedHashMap<String, Double[]>();
/**
* Word segmentation
*
* @author create by rock
*/
public static String[] TextcutWord(String text) throws IOException {
String[] cutWordResult = null;
MMAnalyzer analyzer = new MMAnalyzer();
String tempCutWordResult = analyzer.segment(text, " ");
cutWordResult = tempCutWordResult.split(" ");
return cutWordResult;
}
public static Map<String, HashMap<String, Integer>> NormalTFOfAll(String key1, String key2, String text1,
String text2) throws IOException {
if (allTheNormalTF.get(key1) == null) {
HashMap<String, Integer> dict1 = normalTF(TextcutWord(text1));
allTheNormalTF.put(key1, dict1);
}
if (allTheNormalTF.get(key2) == null) {
HashMap<String, Integer> dict2 = normalTF(TextcutWord(text2));
allTheNormalTF.put(key2, dict2);
}
return allTheNormalTF;
}
public static Map<String, HashMap<String, Double>> tfOfAll(String key1, String key2, String text1, String text2)
throws IOException {
allTheTf.clear();
HashMap<String, Double> dict1 = tf(TextcutWord(text1));
HashMap<String, Double> dict2 = tf(TextcutWord(text2));
allTheTf.put(key1, dict1);
allTheTf.put(key2, dict2);
return allTheTf;
}
/**
* Compute term frequency
*
* @author create by rock
*/
public static HashMap<String, Double> tf(String[] cutWordResult) {
HashMap<String, Double> tf = new HashMap<String, Double>(); // normalized term frequency
int wordNum = cutWordResult.length;
int wordtf = 0;
for (int i = 0; i < wordNum; i++) {
wordtf = 0;
if (cutWordResult[i] != " ") {
for (int j = 0; j < wordNum; j++) {
if (i != j) {
if (cutWordResult[i].equals(cutWordResult[j])) {
cutWordResult[j] = " ";
wordtf++;
}
}
}
tf.put(cutWordResult[i], (double) (++wordtf) / wordNum);
cutWordResult[i] = " ";
}
}
return tf;
}
public static HashMap<String, Integer> normalTF(String[] cutWordResult) {
HashMap<String, Integer> tfNormal = new HashMap<String, Integer>(); // raw counts, not normalized
int wordNum = cutWordResult.length;
int wordtf = 0;
for (int i = 0; i < wordNum; i++) {
wordtf = 0;
if (cutWordResult[i] != " ") {
for (int j = 0; j < wordNum; j++) {
if (i != j) {
if (cutWordResult[i].equals(cutWordResult[j])) {
cutWordResult[j] = " ";
wordtf++;
}
}
}
tfNormal.put(cutWordResult[i], ++wordtf);
cutWordResult[i] = " ";
}
}
return tfNormal;
}
public static Map<String, Double> idf(String key1, String key2, String text1, String text2)
throws FileNotFoundException, UnsupportedEncodingException, IOException {
// Formula: IDF = log((1 + |D|) / |Dt|), where |D| is the total number of documents and |Dt| is the number of documents containing term t.
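// Worked numbers (added for illustration): with D = 2 documents, a term found in only one of them (Dt = 1) gets idf = log10((1 + 2) / 1) ≈ 0.477, while a term found in both (Dt = 2) gets log10(3 / 2) ≈ 0.176.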
Map<String, Double> idf = new HashMap<String, Double>();
List<String> located = new ArrayList<String>();
NormalTFOfAll(key1, key2, text1, text2);
float Dt = 1;
float D = allTheNormalTF.size(); // total number of documents
List<String> key = fileList; // list of document keys
String[] keyarr = new String[2];
keyarr[0] = key1;
keyarr[1] = key2;
for(String item :keyarr) {
if (!fileList.contains(item)) {
fileList.add(item);
}
}
Map<String, HashMap<String, Integer>> tfInIdf = allTheNormalTF; // per-document term counts
for (int i = 0; i < D; i++) {
HashMap<String, Integer> temp = tfInIdf.get(key.get(i));
for (String word : temp.keySet()) {
Dt = 1;
if (!(located.contains(word))) {
for (int k = 0; k < D; k++) {
if (k != i) {
HashMap<String, Integer> temp2 = tfInIdf.get(key.get(k));
if (temp2.containsKey(word)) {
located.add(word);
Dt = Dt + 1;
continue;
}
}
}
idf.put(word, Math.log10((1 + D) / Dt)); // base-10 log, as in the formula above
}
}
}
return idf;
}
public static Map<String, HashMap<String, Double>> tfidf(String key1, String key2, String text1, String text2)
throws IOException {
Map<String, Double> idf = idf(key1, key2, text1, text2);
tfOfAll(key1, key2, text1, text2);
for (String key : allTheTf.keySet()) {
Map<String, Double> singelFile = allTheTf.get(key);
int length = idf.size();
Double[] arr = new Double[length];
int index = 0;
for (String word : singelFile.keySet()) {
singelFile.put(word, (idf.get(word)) * singelFile.get(word));
}
for (String word : idf.keySet()) {
arr[index] = singelFile.get(word) != null ? singelFile.get(word) : 0d;
index++;
}
vectorMap.put(key, arr);
}
return allTheTf;
}
/* Once the term vectors are built, compare them with cosine similarity */
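/* Cosine formula (note added for clarity): sim(v1, v2) = (v1 · v2) / (|v1| * |v2|) */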
public static Double sim(String key1, String key2) {
Double[] arr1 = vectorMap.get(key1);
Double[] arr2 = vectorMap.get(key2);
int length = arr1.length;
Double result1 = 0.00; // squared magnitude of vector 1
Double result2 = 0.00; // squared magnitude of vector 2
Double sum = 0d;
if (length == 0) {
return 0d;
}
for (int i = 0; i < length; i++) {
result1 += arr1[i] * arr1[i];
result2 += arr2[i] * arr2[i];
sum += arr1[i] * arr2[i];
}
Double result = Math.sqrt(result1 * result2);
System.out.println(key1 + " and " + key2 + " similarity: " + sum / result);
return sum / result;
}
}
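A minimal usage sketch (not from the original post; the key names and sample strings below are made up): build the weighted vectors with tfidf, then compare them with sim.

import java.io.IOException;

public class GetTextDemo {
    public static void main(String[] args) throws IOException {
        String text1 = "今天天气很好,我们去公园散步";
        String text2 = "今天天气不错,我们去公园跑步";
        // The keys are arbitrary labels for the two texts.
        GetText.tfidf("doc1", "doc2", text1, text2);
        Double similarity = GetText.sim("doc1", "doc2");
        System.out.println("similarity = " + similarity);
    }
}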
Matching multiple files
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jeasy.analysis.MMAnalyzer;

/**
* Match documents from a corpus directory.
* @author rock
*
*/
public class ReadFiles {
private static List<String> fileList = new ArrayList<String>();
private static HashMap<String, HashMap<String, Float>> allTheTf = new HashMap<String, HashMap<String, Float>>();
private static HashMap<String, HashMap<String, Integer>> allTheNormalTF = new HashMap<String, HashMap<String, Integer>>();
private static LinkedHashMap<String, Float[]> vectorMap = new LinkedHashMap<String, Float[]>();
/**
* Read the corpus directory (recursively collects file paths)
* @author create by rock
*/
public static List<String> readDirs(String filepath) throws FileNotFoundException, IOException {
try {
File file = new File(filepath);
if (!file.isDirectory()) {
System.out.println("输入的参数应该为[文件夹名]");
System.out.println("filepath: " + file.getAbsolutePath());
} else if (file.isDirectory()) {
String[] filelist = file.list();
for (int i = 0; i < filelist.length; i++) {
File readfile = new File(filepath + "\\" + filelist[i]);
if (!readfile.isDirectory()) {
fileList.add(readfile.getAbsolutePath());
} else if (readfile.isDirectory()) {
readDirs(filepath + "\\" + filelist[i]);
}
}
}
} catch (FileNotFoundException e) {
System.out.println(e.getMessage());
}
return fileList;
}
/**
* Read a txt file into a string
* @author create by rock
*/
public static String readFiles(String file) throws FileNotFoundException, IOException {
StringBuffer sb = new StringBuffer();
InputStreamReader is = new InputStreamReader(new FileInputStream(file), "utf-8");
BufferedReader br = new BufferedReader(is);
String line = br.readLine();
while (line != null) {
sb.append(line).append("\r\n");
line = br.readLine();
}
br.close();
return sb.toString();
}
/**
* Word segmentation
* @author create by rock
*/
public static String[] cutWord(String file) throws IOException {
String[] cutWordResult = null;
String text = ReadFiles.readFiles(file);
MMAnalyzer analyzer = new MMAnalyzer();
String tempCutWordResult = analyzer.segment(text, " ");
cutWordResult = tempCutWordResult.split(" ");
return cutWordResult;
}
/**
* Compute term frequency
* @author create by rock
*/
public static HashMap<String, Float> tf(String[] cutWordResult) {
HashMap<String, Float> tf = new HashMap<String, Float>(); // normalized term frequency
int wordNum = cutWordResult.length;
int wordtf = 0;
for (int i = 0; i < wordNum; i++) {
wordtf = 0;
for (int j = 0; j < wordNum; j++) {
if (cutWordResult[i] != " " && i != j) {
if (cutWordResult[i].equals(cutWordResult[j])) {
cutWordResult[j] = " ";
wordtf++;
}
}
}
if (cutWordResult[i] != " ") {
tf.put(cutWordResult[i], (new Float(++wordtf)) / wordNum);
cutWordResult[i] = " ";
}
}
return tf;
}
public static HashMap<String, Integer> normalTF(String[] cutWordResult) {
HashMap<String, Integer> tfNormal = new HashMap<String, Integer>(); // raw counts, not normalized
int wordNum = cutWordResult.length;
int wordtf = 0;
for (int i = 0; i < wordNum; i++) {
wordtf = 0;
if (cutWordResult[i] != " ") {
for (int j = 0; j < wordNum; j++) {
if (i != j) {
if (cutWordResult[i].equals(cutWordResult[j])) {
cutWordResult[j] = " ";
wordtf++;
}
}
}
tfNormal.put(cutWordResult[i], ++wordtf);
cutWordResult[i] = " ";
}
}
return tfNormal;
}
public static Map<String, HashMap<String, Float>> tfOfAll(String dir) throws IOException {
List<String> fileList = ReadFiles.readDirs(dir);
for (String file : fileList) {
HashMap<String, Float> dict = ReadFiles.tf(ReadFiles.cutWord(file));
allTheTf.put(file, dict);
}
return allTheTf;
}
/**
* Same as above, but for an explicit list of document files
* @author create by rock
*/
public static Map<String, HashMap<String, Float>> tfOfAll(String[] files) throws IOException {
for (String file : files) {
HashMap<String, Float> dict = ReadFiles.tf(ReadFiles.cutWord(file));
allTheTf.put(file, dict);
}
return allTheTf;
}
public static Map<String, HashMap<String, Integer>> NormalTFOfAll(String dir) throws IOException {
List<String> fileList = ReadFiles.readDirs(dir);
for (int i = 0; i < fileList.size(); i++) {
HashMap<String, Integer> dict = ReadFiles.normalTF(ReadFiles.cutWord(fileList.get(i)));
allTheNormalTF.put(fileList.get(i), dict);
}
return allTheNormalTF;
}
public static Map<String, Float> idf(String dir) throws FileNotFoundException, UnsupportedEncodingException, IOException {
// Formula: IDF = log((1 + |D|) / |Dt|), where |D| is the total number of documents and |Dt| is the number of documents containing term t.
Map<String, Float> idf = new HashMap<String, Float>();
List<String> located = new ArrayList<String>();
NormalTFOfAll(dir);
float Dt = 1;
float D = allTheNormalTF.size(); // total number of documents
List<String> key = fileList; // list of document paths
Map<String, HashMap<String, Integer>> tfInIdf = allTheNormalTF; // per-document term counts
for (int i = 0; i < D; i++) {
HashMap<String, Integer> temp = tfInIdf.get(key.get(i));
for (String word : temp.keySet()) {
Dt = 1;
if (!(located.contains(word))) {
for (int k = 0; k < D; k++) {
if (k != i) {
HashMap<String, Integer> temp2 = tfInIdf.get(key.get(k));
if (temp2.containsKey(word)) {
located.add(word);
Dt = Dt + 1;
continue;
}
}
}
idf.put(word, (float) Math.log10((1 + D) / Dt)); // base-10 log, as in the formula above
}
}
}
return idf;
}
public static Map<String, HashMap<String, Float>> tfidf(String dir) throws IOException {
Map<String, Float> idf = ReadFiles.idf(dir);
Map<String, HashMap<String, Float>> tf = ReadFiles.tfOfAll(dir);
for (String file : tf.keySet()) {
Map<String, Float> singelFile = tf.get(file);
int length = idf.size();
Float[] arr = new Float[length];
int index = 0;
for (String word : singelFile.keySet()) {
singelFile.put(word, (idf.get(word)) * singelFile.get(word));
}
for(String word : idf.keySet()) {
if(singelFile.get(word) != null) {
arr[index] = singelFile.get(word);
}else {
arr[index] = 0f;
}
index++;
}
vectorMap.put(file, arr);
}
return tf;
}
public static double sim(String file1,String file2) {
Float [] arr1 = vectorMap.get(file1);
Float [] arr2 = vectorMap.get(file2);
int length = arr1.length;
double result1 = 0.00; // squared magnitude of vector 1
double result2 = 0.00; // squared magnitude of vector 2
Float sum = 0f;
for (int i = 0; i < length; i++) {
result1 += arr1[i] * arr1[i];
result2 += arr2[i] * arr2[i];
sum += arr1[i] * arr2[i];
}
double result = Math.sqrt(result1 * result2);
System.out.println(sum/result);
return sum/result;
}
}
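Again a minimal usage sketch (the directory path and file names below are made-up placeholders): build TF-IDF vectors for every file under a corpus directory, then compare any two files by the absolute paths that readDirs stored as keys.

import java.io.IOException;

public class ReadFilesDemo {
    public static void main(String[] args) throws IOException {
        String dir = "D:\\corpus"; // hypothetical corpus directory
        ReadFiles.tfidf(dir); // builds the TF-IDF vector for every file under dir
        // Keys in vectorMap are absolute file paths, as collected by readDirs.
        double similarity = ReadFiles.sim("D:\\corpus\\a.txt", "D:\\corpus\\b.txt");
        System.out.println("similarity = " + similarity);
    }
}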