Java 实现:使用向量相似度(基于词频的余弦相似度),对输入字符串在预先定义好的字符串集合中匹配出最相近的一项。注意:该方法比较的是词面重合程度,并非真正的语义理解。

以下是完整的 Java 示例代码,包括字符串集合的定义和根据输入字符串匹配最相似字符串的逻辑:

import java.util.*;

/**
 * Finds, in a fixed collection of strings, the one closest to an input string
 * using cosine similarity over bag-of-words term-frequency vectors.
 *
 * <p>The comparison is purely lexical: two sentences are similar only to the
 * extent that they share the same lowercase ASCII words.
 */
public class SemanticMatching {

    /** Similarity a candidate must strictly exceed to be returned by the two-argument lookup. */
    private static final double DEFAULT_MIN_SIMILARITY = 0.5;

    public static void main(String[] args) {
        // Candidate strings to match against.
        List<String> stringCollection = Arrays.asList(
                "Where is the restroom?",
                "Can you tell me where the bathroom is?",
                "How do I get to the nearest toilet?",
                "Is there a sign for the restroom around here?",
                "Could you please point me to the nearest lavatory?",
                "Where is the front desk?",
                "Is the front desk nearby?",
                "Where can I find the lobby desk?",
                "Could you please direct me to the front counter?"
        );

        // Input string to look up.
        String inputString = "Can you help me find the reception?";

        // Find the most similar candidate (null when nothing clears the threshold).
        String mostSimilarString = findMostSimilarString(inputString, stringCollection);
        System.out.println("Most Similar String: " + mostSimilarString);
    }

    /**
     * Returns the candidate most similar to {@code input}, using the default
     * threshold of 0.5.
     *
     * @param input      the string to look up
     * @param collection the candidate strings
     * @return the first candidate with the highest similarity strictly above 0.5,
     *         or {@code null} if no candidate exceeds that threshold
     */
    public static String findMostSimilarString(String input, List<String> collection) {
        return findMostSimilarString(input, collection, DEFAULT_MIN_SIMILARITY);
    }

    /**
     * Returns the candidate most similar to {@code input}.
     *
     * <p>Each candidate's similarity is printed to standard output as it is evaluated.
     * On ties the candidate that appears first in {@code collection} wins.
     *
     * @param input         the string to look up
     * @param collection    the candidate strings
     * @param minSimilarity similarity a candidate must strictly exceed to be eligible
     * @return the best candidate, or {@code null} if none exceeds {@code minSimilarity}
     */
    public static String findMostSimilarString(String input, List<String> collection,
                                               double minSimilarity) {
        double bestSimilarity = minSimilarity;
        String mostSimilar = null;

        for (String candidate : collection) {
            double similarity = calculateCosineSimilarity(input, candidate);
            System.out.println("candidate = " + candidate +"; similarity = " + similarity);
            if (similarity > bestSimilarity) {
                bestSimilarity = similarity;
                mostSimilar = candidate;
            }
        }

        return mostSimilar;
    }

    /**
     * Computes the cosine similarity of the word-frequency vectors of two texts.
     *
     * @param text1 first text
     * @param text2 second text
     * @return a value in [0, 1]; 0.0 when either text contains no words
     */
    public static double calculateCosineSimilarity(String text1, String text2) {
        Map<String, Integer> vector1 = getWordFrequency(text1);
        Map<String, Integer> vector2 = getWordFrequency(text2);

        // Union of the words of both texts defines the vector dimensions.
        Set<String> allWords = new HashSet<>(vector1.keySet());
        allWords.addAll(vector2.keySet());

        double[] vec1 = new double[allWords.size()];
        double[] vec2 = new double[allWords.size()];

        int index = 0;
        for (String word : allWords) {
            vec1[index] = vector1.getOrDefault(word, 0);
            vec2[index] = vector2.getOrDefault(word, 0);
            index++;
        }

        return calculateCosine(vec1, vec2);
    }

    /**
     * Tokenizes {@code text} into lowercase ASCII words and counts occurrences.
     *
     * <p>Every character other than {@code a-z} and whitespace is removed before
     * splitting, so {@code "don't"} becomes {@code "dont"}.
     *
     * @param text the text to tokenize
     * @return word to occurrence count; empty when the text contains no words
     */
    private static Map<String, Integer> getWordFrequency(String text) {
        // Locale.ROOT keeps lowercasing stable regardless of the default locale.
        // Whitespace of any kind is preserved so that tabs/newlines still separate words.
        String normalized = text.toLowerCase(Locale.ROOT).replaceAll("[^a-z\\s]", "").trim();

        Map<String, Integer> frequencyMap = new HashMap<>();
        if (normalized.isEmpty()) {
            // Splitting "" would yield a single empty token that is not a real word.
            return frequencyMap;
        }
        for (String word : normalized.split("\\s+")) {
            frequencyMap.merge(word, 1, Integer::sum);
        }
        return frequencyMap;
    }

    /**
     * Computes the cosine of the angle between two equal-length vectors.
     *
     * @return dot(vec1, vec2) / (|vec1| * |vec2|), or 0.0 if either vector has zero length
     */
    private static double calculateCosine(double[] vec1, double[] vec2) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;

        for (int i = 0; i < vec1.length; i++) {
            dotProduct += vec1[i] * vec2[i];
            normA += vec1[i] * vec1[i];
            normB += vec2[i] * vec2[i];
        }
        normA = Math.sqrt(normA);
        normB = Math.sqrt(normB);
        return (normA == 0 || normB == 0) ? 0.0 : dotProduct / (normA * normB);
    }
}

你可能感兴趣的:(1024程序员节)