在基于自然语言的人机交互系统中,通常会定义一些语义模板来训练NLU (自然语言理解)模型,比如下面的模板可以支持用户通过语音控制机器播放音乐:
其中"@{singer}"是个参数, 代表歌手,比如第一个模板可以匹配这样的用户query:“放几首刘德华的歌”。可以看到,同样是放歌,有很多种不同但相似的说法,但把他们一条一条单独列出来,编辑的成本会比较高, 而且会漏掉一些说法, 不严谨。实际上,上面的5个模板,可以用下面的语义模板表达式来表示:
其中包含中括号("[]") 、尖括号("<>") 和竖线("|")三种元素:
那么,给一个上述的语义模板表达式和用户的query,你能判断用户的quey是否能匹配这个表达式吗?
时间限制: 3S (C/C++以外的语言为: 5 S) 内存限制: 128M (C/C++以外的语言为: 640 M)
输入:
输入数据包含两行,
第一行,上述格式的语义模板表达式
第二行,用户的自然语言指令(即:用户query)
输出:
当前query是否匹配当前语义模板表达式。匹配,则输出1,否则输出0.
输入范例:
<[播]放|来>[一|几]<首|曲|个>@{singer}的<歌[曲]|[流行]音乐>
来几首@{singer}的流行歌曲
输出范例:
0
首先说明,这里提供的是博主本人的解题方法。这道题考试时给的时间是半小时,但是博主菜鸡,前后写了差不多两个小时才写完(考完之后继续写的)。所以这里给出的解法可能并不完善,也不能保证正确,因为没法跑测试用例了>.<
言归正传:按照给定的语义模板,列出所有可能的表达组合,形成query指令集。当用户输入query时,与集合内的指令逐条比较,如果存在则输出1,如果不存在则输出0。
所以关键就在列出所有可能的表达组合。采用的思路是递归解析,即我们需要实现一个函数:
vector
则对于模板中的符号,分别做如下处理:
analysePattern(A)
的结果集并上analysePattern(B)
的结果集;上面的2、3两条追加处理详见函数vector
。
完成这些基本的处理后,还需要注意一些细节:形如"[歌]曲"的模板中,虽然"曲"没有任何符号指定,但实际上这个模板等同于"[歌]<曲>",同样的还有题目输入样例“<[播]放|来>[一|几]<首|曲|个>@{singer}的<歌[曲]|[流行]音乐>”中的"@{singer}的"实际上等同于"<@{singer}的>"。
#include
#include
#include
#include
using namespace std;
vector contentMerge(vector v1, vector v2) {
// v1中的每一个字符串后追加v2中字符的内容
if (v1.size() < 1) {
return v2;
} else {
vector results = vector();
for (int i = 0; i < v1.size(); i++) {
string part1 = v1[i];
for (int j = 0; j < v2.size(); j++) {
results.push_back(part1 + v2[j]);
}
}
return results;
}
}
vector analysePattern(string pattern, int rootDepth, bool addEmpty) {
vector results = vector();
if (addEmpty) {//是否可选,如果可选,可以为空字符串
results.push_back("");
}
vector ors = vector();//存储“或”符号的位置
size_t length = pattern.length();
int depth = rootDepth;
//check '|'
for (int i = 0; i < length; i++) {
char c = pattern[i];
if (c == '<' || c == '[') {
depth++;
} else if (c == '>' || c == ']') {
depth--;
} else if (c == '|' && depth == rootDepth) {//与根同层的或符号,直接拆分为子pattern处理
ors.push_back(i);
}
}
int start = 0;
if (ors.size() > 0) {//包含与根同层的或符号,直接拆分处理
for (int i = 0; i < ors.size(); i++) {
int end = ors[i];
vector part = analysePattern(pattern.substr(start, end - start), rootDepth, false);
start = end + 1;
results.insert(results.end(), part.begin(), part.end());
}
//add last
size_t end = pattern.length();
vector part = analysePattern(pattern.substr(start, end - start), rootDepth, false);
results.insert(results.end(), part.begin(), part.end());
} else {
int depth = rootDepth;
stack stacks = stack();
bool hasPattern = false;//是否包含标记符 | <> []
int necessary_Start = 0;//标记形如"歌[曲]"或"[歌]曲"的情况,他们等同于"<歌>[曲]"或"[歌]<曲>"
for (int i = 0; i < length; i++) {
char c = pattern[i];
if (c == '<' || c == '[') {
if (depth == rootDepth) {
if (i > necessary_Start) {
vector part = analysePattern(pattern.substr(necessary_Start, i - necessary_Start), depth, false);
results = contentMerge(results, part);
}
}
hasPattern = true;
depth++;
stacks.push(i);
} else if (c == '>') {
depth--;
if (depth == rootDepth) {
necessary_Start = i + 1;
}
int index = stacks.top();
stacks.pop();
if (depth == rootDepth) {
vector part = analysePattern(pattern.substr(index + 1, i - index - 1), depth + 1, false);
results = contentMerge(results, part);
}
} else if (c == ']') {
depth--;
if (depth == rootDepth) {
necessary_Start = i + 1;
}
int index = stacks.top();
stacks.pop();
if (depth == rootDepth) {
vector part = analysePattern(pattern.substr(index + 1, i - index - 1), depth + 1, true);
results = contentMerge(results, part);
}
}
}
if (!hasPattern) {
results.push_back(pattern);
} else {
if (necessary_Start < pattern.length()) {
vector part = analysePattern(pattern.substr(necessary_Start, pattern.length() - necessary_Start), depth, false);
results = contentMerge(results, part);
}
}
}
return results;
}
bool isMatch(vector set, string query) {
for (int i = 0; i < set.size(); i++) {
if (query.compare(set[i]) == 0) {
return true;
}
}
return false;
}
int main() {
//vector allExpress = analysePattern("<[播]放|来>[一|几]<首|曲|个>@{singer}的<歌[曲]|[流行]音乐>", 0, false);
vector allExpress = analysePattern("<[播]放|来>[一|几]<首|曲|个>@{singer}的<歌[曲]|[流行]音乐>", 0, false);
/*
cout << "Print All Expression" << endl;
for (int i = 0; i < allExpress.size(); i++) {
cout << allExpress[i] << endl;
}
*/
cout << (int)isMatch(allExpress, "来几首@{singer}的流行歌曲") << endl;
return 0;
}