语音转写正确率计算

语音转文字准确率计算 转写正确率的衡量项:ACC、Corr
H = 正确的字数
D = 删除错误的字数,例如参考文本“我是中国人”被转写为“我是中国”
S = 替换错误的字数,例如参考文本“我是中国人”被转写为“我是中华人”
I = 插入错误的字数,例如参考文本“我是中国人”被转写为“我是中国男人”
N = 参考文本(sample)的总字数,即 N = H + S + D
ACC = (H - I)/N
Corr = H/N

思路:
sample "我是中国人学大生,今天要测试,录音结束"
test "中国人民生,今天有个大事情,我想吃饭"

语音转写的文字需要遵从语义,不能只要某个字在转写结果中出现过就算正确。在不考虑其他复杂因素的前提下,
要在test中找到sample中每个字符对应的位置,并且匹配到的字符顺序必须与它们在sample中的顺序一致(排除标点符号)。

如:
sample中每个字在test中找到的位置(找不到记为-1):

我 是 中 国 人 学 大 生 今 天 要 测 试 录 音 结 束
12 -1 0 1 2 -1 9 4 5 6 -1 -1 -1 -1 -1 -1 -1

那么如何找到正确的文字个数呢?应该就是在这组序号中剔除未找到的-1,然后从剩下的序号中找到最长上升子序列(Longest Increasing Subsequence,LIS)(竟然是个算法问题,丢!)


剔除-1后:12, 0, 1, 2, 9, 4, 5, 6

如图,剔除-1,剩下的就是12,0,1,2,9,4,5,6,最长升序那不就是0,1,2,4,5,6吗,对应的文字就是“中国人生今天”,那么算法如何实现呢?

这里给出个笨办法:
遍历序列,将每个升序数组都保存在list里,如果不是升序,就新建一个list。
如当遇到12,则新建一个list:


[12]

当遇到0,则需要新建一个list


[12]
[0]

另外还需要注意,即使是升序,也不一定就能组成最长
如0 1 2,后面是9,如果加进去了,就只能组成0 1 2 9,长度为4,如果放弃加9,则有机会组成0 1 2 4 5 6,长度为6。所以每次添加一个符合升序规则的数字的时候,我们要提前将原list备份一个,留个机会看能否组成更长的序列。

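剩下就是计算删除、替换、插入的文字个数了:在相邻两个匹配字符之间(以及第一个匹配之前、最后一个匹配之后),sample和test中未匹配的字符两两配对计为替换,sample多出的部分计为删除,test多出的部分计为插入。具体看代码逻辑就行了。
本文代码没有考虑时间复杂度和内存占用(用于测试),请自行优化。注意:上面的笨办法每处理一个有效序号,list的数量就会翻倍,时间和内存都是指数级增长的,只适合很短的文本;实际使用时应改用经典的最长上升子序列动态规划(O(n²))或二分优化(O(n log n))算法。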
上代码:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Set;
import java.util.function.IntPredicate;
import java.util.logging.Logger;

/**
 * Speech-to-text accuracy calculator.
 *
 * <p>The reference text ({@code original}) and the transcript ({@code test}) are compared
 * character by character after punctuation is removed. The metrics are:
 * <ul>
 *   <li>H - correctly transcribed characters (hits)</li>
 *   <li>D - deletions: reference characters missing from the transcript</li>
 *   <li>S - substitutions: reference characters replaced by a different character</li>
 *   <li>I - insertions: extra transcript characters with no reference counterpart</li>
 *   <li>N - number of characters in the reference, so N = H + S + D</li>
 * </ul>
 * ACC = (H - I) / N and Corr = H / N.
 *
 * <p>The alignment is approximate: each reference character is greedily mapped to its first
 * unused occurrence in the transcript, and the longest strictly increasing subsequence of those
 * positions is taken as the set of hits.
 */
public class StringCompare {
    private static final String ORIGINAL_STRING = "我是中国人学大生,今天要测试,录音结束";

    private static final String TEST_STRING = "中国人民生,今天有个大事情,我想吃饭";

    /** Error counts for one reference/transcript pair, plus the derived ACC and Corr rates. */
    static final class Metrics {
        /** H: characters that were transcribed correctly. */
        final int correct;
        /** N: number of characters in the reference after punctuation removal. */
        final int total;
        /** D: reference characters missing from the transcript. */
        final int deletions;
        /** S: reference characters replaced by a different transcript character. */
        final int substitutions;
        /** I: extra transcript characters with no reference counterpart. */
        final int insertions;

        Metrics(int correct, int total, int deletions, int substitutions, int insertions) {
            this.correct = correct;
            this.total = total;
            this.deletions = deletions;
            this.substitutions = substitutions;
            this.insertions = insertions;
        }

        /** Returns ACC = (H - I) / N, or 0 when the reference is empty. */
        float acc() {
            return total == 0 ? 0f : (float) (correct - insertions) / total;
        }

        /** Returns Corr = H / N, or 0 when the reference is empty. */
        float corr() {
            return total == 0 ? 0f : (float) correct / total;
        }
    }

    /**
     * Scores {@code test} (the transcript) against {@code original} (the reference).
     * Punctuation is removed from both strings before they are compared.
     *
     * @param original reference text
     * @param test transcript to score
     * @return the H/N/D/S/I counts and derived rates
     */
    static Metrics computeMetrics(String original, String test) {
        String ref = removePunctuation(original);
        String hyp = removePunctuation(test);
        return countErrors(ref, hyp, getMaxSortedIndexs(getInIndex(ref, hyp)));
    }

    /** Removes every Unicode punctuation character (regex class pP) from {@code sentence}. */
    private static String removePunctuation(String sentence) {
        return sentence.replaceAll("\\pP", "");
    }

    /** Prints the counts and rates of {@code m} to standard output. */
    private static void print(Metrics m) {
        System.out.println("H:" + m.correct + ", N:" + m.total + ", D:" + m.deletions
                + ", I:" + m.insertions + ", S:" + m.substitutions);
        System.out.println("mAcc" + m.acc() + ",mCorr:" + m.corr());
    }

    /**
     * Maps every character of {@code original} to a position in {@code test}.
     *
     * <p>Each character takes its first occurrence in {@code test} that has not already been
     * claimed by an earlier character; characters with no free occurrence map to -1.
     *
     * @return array of length {@code original.length()} holding positions in {@code test} or -1
     */
    private static int[] getInIndex(String original, String test) {
        int[] positions = new int[original.length()];
        boolean[] claimed = new boolean[test.length()];
        for (int i = 0; i < original.length(); i++) {
            char c = original.charAt(i);
            int pos = test.indexOf(c);
            // Skip occurrences already matched so one transcript character is never reused.
            while (pos != -1 && claimed[pos]) {
                pos = test.indexOf(c, pos + 1);
            }
            if (pos != -1) {
                claimed[pos] = true;
            }
            positions[i] = pos;
        }
        return positions;
    }

    /**
     * Finds the longest strictly increasing subsequence among the non-negative entries of
     * {@code index}, using O(n^2) dynamic programming instead of enumerating candidate lists.
     *
     * @param index positions produced by {@link #getInIndex}; -1 entries are ignored
     * @return map from reference position to transcript position for every hit, in increasing
     *     key order; empty when nothing matched
     */
    private static Map<Integer, Integer> getMaxSortedIndexs(int[] index) {
        int n = index.length;
        int[] length = new int[n];   // length of the best increasing run ending at i
        int[] previous = new int[n]; // predecessor of i in that run, or -1
        int best = -1;
        for (int i = 0; i < n; i++) {
            previous[i] = -1;
            if (index[i] < 0) {
                continue;
            }
            length[i] = 1;
            for (int j = 0; j < i; j++) {
                if (index[j] >= 0 && index[j] < index[i] && length[j] + 1 > length[i]) {
                    length[i] = length[j] + 1;
                    previous[i] = j;
                }
            }
            if (best == -1 || length[i] > length[best]) {
                best = i;
            }
        }
        Map<Integer, Integer> hits = new LinkedHashMap<>();
        if (best == -1) {
            return hits;
        }
        // Walk the predecessor chain backwards, then emit it in increasing order.
        int[] chain = new int[length[best]];
        for (int k = chain.length - 1, cur = best; k >= 0; k--, cur = previous[cur]) {
            chain[k] = cur;
        }
        for (int key : chain) {
            hits.put(key, index[key]);
        }
        return hits;
    }

    /**
     * Derives H, D, S and I from the hit positions.
     *
     * <p>Between two consecutive hits (and before the first / after the last one) the unmatched
     * reference and transcript characters are paired up as substitutions; any surplus on the
     * reference side counts as deletions and any surplus on the transcript side as insertions.
     */
    private static Metrics countErrors(String original, String test, Map<Integer, Integer> hits) {
        int[] counts = new int[3]; // {D, S, I}
        // -1 sentinels make the gap before the first hit use the same formula as later gaps.
        int prevKey = -1;
        int prevValue = -1;
        for (Map.Entry<Integer, Integer> hit : hits.entrySet()) {
            addGap(counts, hit.getKey() - prevKey - 1, hit.getValue() - prevValue - 1);
            prevKey = hit.getKey();
            prevValue = hit.getValue();
        }
        // Trailing gap after the last hit (the whole strings when nothing matched).
        addGap(counts, original.length() - prevKey - 1, test.length() - prevValue - 1);
        // N is the reference length, as required by ACC = (H - I) / N and Corr = H / N.
        return new Metrics(hits.size(), original.length(), counts[0], counts[1], counts[2]);
    }

    /** Adds one gap of {@code refGap}/{@code hypGap} unmatched characters to {D, S, I}. */
    private static void addGap(int[] counts, int refGap, int hypGap) {
        counts[0] += Math.max(0, refGap - hypGap);
        counts[1] += Math.min(refGap, hypGap);
        counts[2] += Math.max(0, hypGap - refGap);
    }

    public static void main(String[] args) {
        long start = System.currentTimeMillis();
        // Position of every reference character in the transcript.
        String original = removePunctuation(ORIGINAL_STRING);
        String test = removePunctuation(TEST_STRING);
        int[] index = getInIndex(original, test);
        System.out.println(Arrays.toString(index));
        Map<Integer, Integer> hits = getMaxSortedIndexs(index);
        System.out.println("cost:" + (System.currentTimeMillis() - start));

        System.out.println("在original字符串中正确的字符");
        System.out.println("[" + ORIGINAL_STRING + "]");
        for (int key : hits.keySet()) {
            System.out.println("index:" + key + "value:" + original.charAt(key));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("在test字符串中正确的字符\n");
        System.out.println("[" + TEST_STRING + "]");
        for (int value : hits.values()) {
            System.out.println("index:" + value + "value:" + test.charAt(value));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("计算准确率\n");
        print(countErrors(original, test, hits));
        System.out.println("--------------------------------------------------------------");
    }
}

C#版本

using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using System.Linq;

namespace Calibration.utils
{
    /// <summary>
    /// Scores speech-to-text transcripts against a reference text, character by character.
    /// H = hits, D = deletions, S = substitutions, I = insertions, N = reference length
    /// (N = H + S + D); Acc = (H - I) / N and Corr = H / N.
    /// </summary>
    class EsrUtils
    {
        // Punctuation, symbols and whitespace are ignored when comparing texts.
        private static readonly Regex NonContentPattern = new Regex(@"[\p{P}\p{S}\s]");

        private EsrUtils() { }

        /// <summary>
        /// Removes punctuation (\p{P}), symbols (\p{S}) and whitespace from <paramref name="sentence"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="sentence"/> is null.</exception>
        public static string RemovePunctuation(string sentence)
        {
            return NonContentPattern.Replace(sentence, string.Empty);
        }

        /// <summary>
        /// Maps every character of <paramref name="original"/> to a position in <paramref name="test"/>.
        /// Each character takes its first occurrence that has not been claimed by an earlier
        /// character; characters with no free occurrence map to -1.
        /// </summary>
        public static int[] GetInIndex(String original, String test)
        {
            int[] positions = new int[original.Length];
            bool[] claimed = new bool[test.Length];
            for (int i = 0; i < original.Length; i++)
            {
                int pos = test.IndexOf(original[i]);
                // Skip occurrences already matched so one transcript character is never reused.
                while (pos != -1 && claimed[pos])
                {
                    pos = test.IndexOf(original[i], pos + 1);
                }
                if (pos != -1)
                {
                    claimed[pos] = true;
                }
                positions[i] = pos;
            }
            return positions;
        }

        /// <summary>
        /// Finds the longest strictly increasing subsequence among the non-negative entries of
        /// <paramref name="index"/> (O(n^2) dynamic programming; -1 entries are ignored).
        /// </summary>
        /// <returns>
        /// Reference position -> transcript position for every hit, inserted in increasing key
        /// order; empty when nothing matched.
        /// </returns>
        public static Dictionary<int, int> GetMaxSortedIndexs(int[] index)
        {
            int count = index.Length;
            int[] length = new int[count];   // length of the best increasing run ending at i
            int[] previous = new int[count]; // predecessor of i in that run, or -1
            int best = -1;
            for (int i = 0; i < count; i++)
            {
                previous[i] = -1;
                if (index[i] < 0)
                {
                    continue;
                }
                length[i] = 1;
                for (int j = 0; j < i; j++)
                {
                    if (index[j] >= 0 && index[j] < index[i] && length[j] + 1 > length[i])
                    {
                        length[i] = length[j] + 1;
                        previous[i] = j;
                    }
                }
                if (best == -1 || length[i] > length[best])
                {
                    best = i;
                }
            }

            var hits = new Dictionary<int, int>();
            if (best == -1)
            {
                return hits;
            }
            // Walk the predecessor chain backwards, then insert it in increasing order.
            int[] chain = new int[length[best]];
            int cur = best;
            for (int k = chain.Length - 1; k >= 0; k--)
            {
                chain[k] = cur;
                cur = previous[cur];
            }
            foreach (int key in chain)
            {
                hits.Add(key, index[key]);
            }
            return hits;
        }

        /// <summary>
        /// Scores every transcript in <paramref name="testString"/> against <paramref name="originalString"/>.
        /// </summary>
        /// <returns>
        /// One float[7] per transcript: { acc, corr, H, N, D, S, I }; or null if any error occurs
        /// (the exception is written to the console).
        /// </returns>
        public static List<float[]> GetParameters(string originalString, List<string> testString)
        {
            var resultList = new List<float[]>();
            try
            {
                string original = RemovePunctuation(originalString);
                foreach (string rawTest in testString)
                {
                    string test = RemovePunctuation(rawTest);
                    Dictionary<int, int> hits = GetMaxSortedIndexs(GetInIndex(original, test));
                    float[] result = Score(original, test, hits);
                    Console.WriteLine("acc" + result[0] + ",corr: " + result[1]);
                    resultList.Add(result);
                }
            }
            catch (Exception e)
            {
                // Kept for backward compatibility: callers receive null on failure.
                Console.WriteLine("error accur:" + e.ToString());
                return null;
            }
            return resultList;
        }

        // Counts H/D/S/I from the hits and packs them as { acc, corr, H, N, D, S, I }.
        // Between consecutive hits (and before the first / after the last) unmatched characters
        // are paired as substitutions; surplus reference characters are deletions and surplus
        // transcript characters are insertions.
        private static float[] Score(string original, string test, Dictionary<int, int> hits)
        {
            int[] counts = new int[3]; // { D, S, I }
            // -1 sentinels make the gap before the first hit use the same formula as later gaps.
            int prevKey = -1;
            int prevValue = -1;
            // Explicit ordering: Dictionary enumeration order is not guaranteed by the API.
            foreach (KeyValuePair<int, int> hit in hits.OrderBy(p => p.Key))
            {
                AddGap(counts, hit.Key - prevKey - 1, hit.Value - prevValue - 1);
                prevKey = hit.Key;
                prevValue = hit.Value;
            }
            AddGap(counts, original.Length - prevKey - 1, test.Length - prevValue - 1);

            int h = hits.Count;
            // N is the reference length, as required by Acc = (H - I) / N and Corr = H / N.
            int n = original.Length;
            float acc = n == 0 ? 0f : (float)(h - counts[2]) / n;
            float corr = n == 0 ? 0f : (float)h / n;
            return new float[] { acc, corr, h, n, counts[0], counts[1], counts[2] };
        }

        // Adds one gap of refGap/hypGap unmatched characters to { D, S, I }.
        private static void AddGap(int[] counts, int refGap, int hypGap)
        {
            counts[0] += Math.Max(0, refGap - hypGap);
            counts[1] += Math.Min(refGap, hypGap);
            counts[2] += Math.Max(0, hypGap - refGap);
        }

        // NOTE(review): placeholder with no implementation; always throws.
        internal static object GetInstance()
        {
            throw new NotImplementedException();
        }
    }
}

你可能感兴趣的:(语音转写正确率计算)