double-array trie算法实现

以下是公司同事编程大赛代码,后续会把详细分析加上,优化代码结构


#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using namespace std;

//typedef int int32_t;
//typedef unsigned int uint32_t;
//typedef unsigned short uint16_t;
//typedef short int16_t;

namespace utf8
{

// Pack the payload bits of a 1-, 2- or 3-byte UTF-8 sequence into one 16-bit code.
#define ENCODEONECHAR(ch) \
(uint16_t)(ch)

#define ENCODETWOCHAR(ch1, ch2) \
(uint16_t)(((((unsigned char)(ch1))&0x1F)<<6)+(((unsigned char)(ch2))&0x3F))

#define ENCODETHREECHAR(ch1, ch2, ch3) \
(uint16_t)(((((unsigned char)(ch1))&0x0F)<<12)+((((unsigned char)(ch2))&0x3F)<<6) + (((unsigned char)(ch3))&0x3F))

/**
 * Decode the single UTF-8 sequence that starts at p.
 *
 * Supported forms:
 *   U+0000..U+007F : 0xxxxxxx
 *   U+0080..U+07FF : 110xxxxx 10xxxxxx
 *   U+0800..U+FFFF : 1110xxxx 10xxxxxx 10xxxxxx
 *
 * @param p     NUL-terminated input; only as many bytes as needed are read.
 * @param code  receives the decoded code; left untouched on failure.
 * @return number of bytes consumed (1..3), or 0 when p is at the terminator,
 *         the lead byte is not one of the forms above (4-byte sequences do not
 *         fit in 16 bits), or a trailing byte is not a 10xxxxxx continuation
 *         byte (this also covers a string that ends in mid-sequence).
 */
static size_t decodeOne(const char *p, uint16_t &code)
{
    const unsigned char b0 = (unsigned char)p[0];
    if (b0 == 0)
        return 0;

    if ((b0 & 0x80) == 0)           // plain ASCII
    {
        code = ENCODEONECHAR(b0);
        return 1;
    }

    if ((b0 & 0xE0) == 0xC0)        // two-byte sequence
    {
        const unsigned char b1 = (unsigned char)p[1];
        if ((b1 & 0xC0) != 0x80)
            return 0;
        code = ENCODETWOCHAR(b0, b1);
        return 2;
    }

    if ((b0 & 0xF0) == 0xE0)        // three-byte sequence
    {
        const unsigned char b1 = (unsigned char)p[1];
        if ((b1 & 0xC0) != 0x80)    // stops here when p[1] is the terminator, so p[2] is never read
            return 0;
        const unsigned char b2 = (unsigned char)p[2];
        if ((b2 & 0xC0) != 0x80)
            return 0;
        code = ENCODETHREECHAR(b0, b1, b2);
        return 3;
    }

    return 0;
}

/**
 * Decode a whole UTF-8 string into a caller-provided array of 16-bit codes.
 *
 * @param pStr          NUL-terminated UTF-8 input.
 * @param pNormalCodes  output array; must have room for one entry per input byte.
 * @return number of codes written, or 0 if the input is empty or malformed
 *         (codes written before the error are left in the array).
 */
size_t normalizeString(const char *pStr, uint16_t *pNormalCodes)
{
    uint16_t *out = pNormalCodes;
    while (*pStr != '\0')
    {
        uint16_t code;
        const size_t used = decodeOne(pStr, code);
        if (used == 0)
            return 0;
        *out++ = code;
        pStr += used;
    }
    return (size_t)(out - pNormalCodes);
}

/**
 * Decode a whole UTF-8 string, appending the 16-bit codes to a vector.
 *
 * @param pStr          NUL-terminated UTF-8 input.
 * @param pNormalCodes  codes are appended here; existing content is kept.
 * @return pNormalCodes.size() after decoding, or 0 if the input is malformed
 *         (codes appended before the error stay in the vector).
 */
size_t normalizeString(const char *pStr, std::vector<uint16_t> &pNormalCodes)
{
    while (*pStr != '\0')
    {
        uint16_t code;
        const size_t used = decodeOne(pStr, code);
        if (used == 0)
            return 0;
        pNormalCodes.push_back(code);
        pStr += used;
    }
    return pNormalCodes.size();
}

/**
 * Decode a UTF-8 string and record, for every code, the byte offset just past
 * its encoding. Example: a 6-byte input made of two 3-byte characters yields
 * two codes with pPosMatches[0] = 3 and pPosMatches[1] = 6.
 *
 * @param pStr          NUL-terminated UTF-8 input.
 * @param pNormalCodes  output array of codes (one slot per input byte suffices).
 * @param pPosMatches   output array of end offsets, parallel to pNormalCodes.
 * @return number of codes written, or 0 if the input is empty or malformed.
 */
size_t encodeString(const char *pStr, uint16_t *pNormalCodes, size_t *pPosMatches)
{
    const char *const start = pStr;
    uint16_t *out = pNormalCodes;
    while (*pStr != '\0')
    {
        uint16_t code;
        const size_t used = decodeOne(pStr, code);
        if (used == 0)
            return 0;
        *out++ = code;
        pStr += used;
        *pPosMatches++ = (size_t)(pStr - start);
    }
    return (size_t)(out - pNormalCodes);
}

/**
 * Decode the next character and advance the cursor past it.
 *
 * @param pStr  in/out cursor into a NUL-terminated UTF-8 string; advanced only on success.
 * @param code  receives the decoded code; untouched on failure.
 * @return number of bytes consumed (1..3), or 0 at end of string / on malformed input.
 */
size_t encodeNext(const char *&pStr, uint16_t &code)
{
    const size_t used = decodeOne(pStr, code);
    pStr += used;
    return used;
}

/**
 * Count the characters (not bytes) of a UTF-8 string.
 *
 * @param pStr  NUL-terminated UTF-8 input.
 * @return number of characters, or 0 if the input is empty or malformed.
 */
size_t getWordNum(const char *pStr)
{
    size_t wordNum = 0;
    while (*pStr != '\0')
    {
        uint16_t ignored;
        const size_t used = decodeOne(pStr, ignored);
        if (used == 0)
            return 0;
        pStr += used;
        ++wordNum;
    }
    return wordNum;
}


}


/**
 * One slot of the double array: a (base, check) pair.
 */
struct SBC 
{
int32_t m_base;  // base offset of this node's children; stored negated when the node ends a keyword
int32_t m_check; // index of the parent slot; -2 marks a free slot, -1 marks the slot of a first character
SBC();           // creates a free slot (m_base = 0, m_check = -2)
~SBC();
};


/**
 * Build-time description of one trie node: the key prefix that leads to it,
 * plus which keys continue through it and which keys end on it.
 */
struct SInnerDATWUnit
{
    std::vector<uint16_t> m_keys;       // the prefix keys[0..m_startPos] identifying this node
    size_t m_index;                     // position in CBoubleTrieBaseWriter::m_keys of the key that created the node
    size_t m_startPos;                  // position inside the key of the character this node stands for
    std::set<size_t> m_childrenIndex;   // indices of keys that continue past this node (kept in ascending order)
    std::set<size_t> m_endIndex;        // indices of keys that end exactly at this node

    /**
     * @param index     index of the originating key in the writer's key table
     * @param startPos  position of this node's character inside that key
     * @param keys      the full key; only keys[0..startPos] are copied
     */
    SInnerDATWUnit(size_t index, size_t startPos, std::vector<uint16_t>& keys)
        : m_keys(keys.begin(), keys.begin() + startPos + 1),
          m_index(index),
          m_startPos(startPos)
    {
    }

    ~SInnerDATWUnit()
    {
    }

    /** Dump the node (prefix, children, end indices) to stdout for debugging. */
    void print()
    {
        fprintf(stdout, "(%d)keys=", (int)m_index);
        for (size_t k = 0; k <= m_startPos; ++k)
            fprintf(stdout, "%d,", (int)m_keys[k]);

        fprintf(stdout, "\nchildren=");
        printIndexSet(m_childrenIndex);

        fprintf(stdout, "\nendIndex=");
        printIndexSet(m_endIndex);

        fprintf(stdout, "\n\n");
    }

private:
    /** Print every element of an index set followed by a comma. */
    static void printIndexSet(const std::set<size_t>& indices)
    {
        std::set<size_t>::const_iterator it = indices.begin();
        while (it != indices.end())
        {
            fprintf(stdout, "%d,", (int)*it);
            ++it;
        }
    }

};


/**
 * Ordering used when writing nodes into the double array.
 *
 * Nodes with more children come first; ties are broken by the smaller
 * m_startPos and then by the smaller m_index.
 *
 * The former std::binary_function base was dropped: it is deprecated since
 * C++11 and removed in C++17, and std::set does not need it.
 */
struct writesort
{
    bool operator()(const SInnerDATWUnit* unit1,const SInnerDATWUnit* unit2) const
    {
        const size_t children1 = unit1->m_childrenIndex.size();
        const size_t children2 = unit2->m_childrenIndex.size();
        if (children1 != children2)
            return children1 > children2;
        if (unit1->m_startPos != unit2->m_startPos)
            return unit1->m_startPos < unit2->m_startPos;
        return unit1->m_index < unit2->m_index;
    }
};


struct unitsort :public std::binary_function<const SInnerDATWUnit*, const SInnerDATWUnit*, bool> 
{
bool operator()(const SInnerDATWUnit* unit1, const SInnerDATWUnit* unit2) const
{
size_t minPos =unit1->m_startPos > unit2->m_startPos? unit2->m_startPos:unit1->m_startPos;
for(size_t i=0; i<=minPos; i++)
{
if(unit1->m_keys[i] != unit2->m_keys[i]) 
return unit1->m_keys[i] < unit2->m_keys[i];
}
return unit1->m_startPos < unit2->m_startPos;
}
};




/**
 * Lexicographic "less than" on two encoded keys passed by pointer:
 * the first differing code decides, and a proper prefix sorts before the
 * longer key.
 *
 * The former std::binary_function base was dropped: it is deprecated since
 * C++11 and removed in C++17.
 */
struct keysort
{
    bool operator()(const std::vector<uint16_t>* unit1, const std::vector<uint16_t>* unit2) const
    {
        return std::lexicographical_compare(unit1->begin(), unit1->end(),
                                            unit2->begin(), unit2->end());
    }
};


// Builds a double-array trie (base/check arrays) from a set of UTF-8 keywords.
// Usage as shown by this file: addKey() for every keyword, buildUnits() once,
// then write()/getUnitPos()/getSlotNum() to consume the result.
class CBoubleTrieBaseWriter 
{
public:
/* reversedNum: number of leading slots reserved for first characters (slot index == character code) */
CBoubleTrieBaseWriter(int32_t reversedNum=65536);
~CBoubleTrieBaseWriter();

public:
/* Add one keyword; when shouleReverse is true the keyword is stored back to front */
bool addKey(const char *key, bool shouleReverse = false);
/* Build all trie units level by level; returns false when no keyword was added */
bool buildUnits();
/* Serialize the slot count followed by every (base, check) pair into mmapPtr */
bool write(char *mmapPtr);
/* Number of bytes write() needs */
int32_t getDATSize();
/* Number of slots in use */
int32_t getSlotNum();
/* Slot index of a keyword, usable as a key for attaching extra information */
int32_t getUnitPos(const char *key, bool shouleReverse = false);


private:
/* Find a base value under which all children of unit land on free slots */
inline int32_t getConsistentBasePos(SInnerDATWUnit *unit);
/* Write one unit into the base/check array */
inline bool buildUnit(SInnerDATWUnit *unit);
// Expand every added keyword into units for the given level
inline bool preBuild(int32_t level);
/* Expand key number pos at the given level into a SInnerDATWUnit stored in m_unists */
inline void extendUnits(size_t pos, int32_t level);
/* Smallest power of two that is not less than num */
inline int eup2power(int num);
inline bool allocateMemery(int32_t needCap);
inline void _print();

public:
std::set<SInnerDATWUnit *,unitsort> m_unists; // units of the level currently being built
std::set<SInnerDATWUnit *,writesort> m_writeUnits;// the same units, ordered for writing
int32_t m_allocBcNum; // number of SBC slots currently allocated
SBC *m_pBC;
private:
  std::vector<std::vector<uint16_t> > m_keys; // encoded keywords to index (entry 0 is an empty placeholder)
  uint32_t m_buildIncTime; // growth factor applied when the slot array is enlarged
  int32_t m_nextFreePos; // next slot to try when looking for free space
  int32_t m_usedMaxBCPos; // highest slot index in use
  int32_t m_reversedNum; // number of slots reserved for first characters
  int32_t m_maxLevel;     // length of the longest keyword added
};




// A fresh slot is free: m_check == -2 means "free", -1 means "first character".
SBC::SBC()
    : m_base(0),
      m_check(-2)
{
}


// Nothing to release: a slot only holds two integers.
SBC::~SBC() 
{
}


/**
 * Set up an empty writer.
 *
 * @param reversedNum  number of leading slots reserved for first characters;
 *                     free-slot search starts right after them.
 */
CBoubleTrieBaseWriter::CBoubleTrieBaseWriter(int32_t reversedNum)
{
    m_reversedNum  = reversedNum;
    m_nextFreePos  = reversedNum + 1;   // first slot after the reserved first-character area
    m_buildIncTime = 2;
    m_maxLevel     = 0;
    m_usedMaxBCPos = 0;

    // Initial capacity: four times the next power of two above the reserved area.
    m_allocBcNum = 4 * eup2power(m_nextFreePos);
    m_pBC = new SBC[m_allocBcNum];

    // Key 0 is an empty placeholder so that real keys get indices >= 1.
    m_keys.push_back(std::vector<uint16_t>());
}


/**
 * Free any units still owned by m_unists.
 *
 * m_pBC is deliberately not released here: CMyChecker::Init copies the
 * pointer out of the writer and keeps using the array afterwards.
 */
CBoubleTrieBaseWriter::~CBoubleTrieBaseWriter()
{
    typedef std::set<SInnerDATWUnit *, unitsort>::iterator UnitIter;
    for (UnitIter it = m_unists.begin(); it != m_unists.end(); ++it)
    {
        delete *it;
    }
}


/**
 * Encode a UTF-8 keyword and queue it for building.
 *
 * @param key            NUL-terminated UTF-8 keyword.
 * @param shouleReverse  store the encoded keyword back to front.
 * @return false for a NULL/empty key or one that fails to decode, true otherwise.
 */
bool CBoubleTrieBaseWriter::addKey(const char *key, bool shouleReverse)
{
    if (key == NULL || *key == '\0')
        return false;

    // Encode the key into 16-bit codes first.
    std::vector<uint16_t> codes;
    codes.reserve(strlen(key));
    if (utf8::normalizeString(key, codes) == 0)
        return false;

    // Track the longest key: buildUnits iterates over that many levels.
    if (m_maxLevel < codes.size())
        m_maxLevel = codes.size();

    if (shouleReverse)
        m_keys.push_back(std::vector<uint16_t>(codes.rbegin(), codes.rend()));
    else
        m_keys.push_back(codes);
    return true;
}


/**
 * Expand every queued keyword into the units of one level.
 *
 * @param level  zero-based character position to expand.
 * @return false when only the placeholder key exists, true otherwise.
 */
bool CBoubleTrieBaseWriter::preBuild(int32_t level)
{
    const size_t total = m_keys.size();
    if (total <= 1)     // index 0 is the placeholder, so <= 1 means "no keys"
        return false;

    for (size_t keyIdx = 1; keyIdx != total; ++keyIdx)
        extendUnits(keyIdx, level);

    return true;
}


int32_t CBoubleTrieBaseWriter::getUnitPos(const char *key, bool shouleReverse)
{
if(key == NULL || *key == '\0') return -1;
//先对key进行编码
vector<uint16_t> codeValues;
codeValues.reserve(strlen(key));
if(utf8::normalizeString(key, codeValues) == 0) 
return -1;
int32_t curIndex = -1;
if(!shouleReverse)
{
curIndex = codeValues[0]; 
for(size_t i=1; i < codeValues.size(); i++) 
curIndex = abs(m_pBC[curIndex].m_base)+codeValues[i];
}
else
{
curIndex = codeValues[codeValues.size()-1]; 
for(size_t i=1; i < codeValues.size(); i++) 
curIndex = abs(m_pBC[curIndex].m_base)+codeValues[codeValues.size()-1-i];
}
return curIndex;
}


/**
 * Size in bytes of the image produced by write(): a 4-byte slot count,
 * followed by one (base, check) pair of int32_t per used slot.
 */
int32_t CBoubleTrieBaseWriter::getDATSize()
{
    const size_t headerBytes = sizeof(int32_t);
    if (m_usedMaxBCPos == 0)
        return headerBytes;

    const size_t bytesPerSlot = sizeof(int32_t) * 2;
    return (m_usedMaxBCPos + 1) * bytesPerSlot + headerBytes;
}


/** Number of slots in use: highest used index + 1, or 0 when nothing was built. */
int32_t CBoubleTrieBaseWriter::getSlotNum()
{
    return (m_usedMaxBCPos == 0) ? 0 : m_usedMaxBCPos + 1;
}


/**
 * Serialize the trie into caller-provided memory as a sequence of int32_t:
 * the slot count, then base and check of every slot 0..m_usedMaxBCPos.
 * When nothing was built only a zero slot count is written.
 *
 * @param mmapPtr  destination buffer of at least getDATSize() bytes.
 * @return always true.
 */
bool CBoubleTrieBaseWriter::write(char *mmapPtr)
{
    int32_t *out = (int32_t *)mmapPtr;

    if (m_usedMaxBCPos <= 0)
    {
        *out = 0;
        return true;
    }

    // Slot count first, then the pairs.
    *out++ = m_usedMaxBCPos + 1;
    for (int32_t slot = 0; slot <= m_usedMaxBCPos; ++slot)
    {
        *out++ = m_pBC[slot].m_base;
        *out++ = m_pBC[slot].m_check;
    }
    return true;
}


bool CBoubleTrieBaseWriter::buildUnits() 
{
if(m_keys.size() <= 1) return false;
for(int32_t leveli=0; leveli<=m_maxLevel; leveli++)
{
//先预处理
if(!preBuild(leveli)) continue;
if(m_unists.size() == 0) continue;
//按照生成索引的顺序排序
for(set<SInnerDATWUnit *,unitsort>::iterator iter=m_unists.begin(); iter!=m_unists.end(); iter++) 
m_writeUnits.insert(*iter);


//建立索引
for(set<SInnerDATWUnit *,writesort>::iterator iter2=m_writeUnits.begin(); iter2!=m_writeUnits.end(); iter2++) {
//(*iter2)->print();
if(!buildUnit(*iter2))
return false;


//_print();
}
//_print();
//释放资源
m_writeUnits.clear();
for(set<SInnerDATWUnit *, unitsort>::iterator iter=m_unists.begin(); iter!=m_unists.end(); iter++) delete *iter;
m_unists.clear();
}
return true;
}


// Write one unit into the base/check array.
//  - Marks the slot of the unit's first character as a first-character slot (check = -1).
//  - Walks the already-written base values to find the unit's own slot.
//  - If keys end on this unit, makes its base negative (a zero base becomes
//    the negated index of the first ending key).
//  - If the unit has children, picks a base via getConsistentBasePos(), stores
//    it (keeping the sign) and claims each child slot by setting its check to
//    this unit's slot.
// Always returns true.
bool CBoubleTrieBaseWriter::buildUnit( SInnerDATWUnit *unit) 
{
// Locate the absolute slot of the character at unit->m_startPos
int32_t curIndex = unit->m_keys[0]; 
m_pBC[curIndex].m_check = -1;
for(size_t i=1; i<=unit->m_startPos; i++)
curIndex = abs(m_pBC[curIndex].m_base)+unit->m_keys[i];


if(m_usedMaxBCPos < curIndex) // track the highest slot in use
m_usedMaxBCPos = curIndex;


// Set the m_base value
if(unit->m_endIndex.size() != 0) // a keyword ends here, so m_base must be negative
{
if (m_pBC[curIndex].m_base > 0) 
m_pBC[curIndex].m_base *= -1;
else if(m_pBC[curIndex].m_base == 0)
m_pBC[curIndex].m_base = -1 * (*unit->m_endIndex.begin());
} 
// Set m_check of the child slots and m_base of the current slot
if(unit->m_childrenIndex.size() > 0) 
{
// Find a base value that fits all children
int32_t base = getConsistentBasePos(unit);
// Store it as this unit's m_base, preserving the keyword (negative) marker
if(m_pBC[curIndex].m_base < 0) 
m_pBC[curIndex].m_base = -1 * base;
else 
m_pBC[curIndex].m_base = base;


for(set<size_t>::iterator iter = unit->m_childrenIndex.begin(); iter!=unit->m_childrenIndex.end(); iter++) 
{
int32_t addt = base + m_keys[*iter][unit->m_startPos+1];
m_pBC[addt].m_check = curIndex; // claim the child slot
if(m_usedMaxBCPos < addt) {// track the highest slot in use
m_usedMaxBCPos = addt;
}
}
}
return true;
}


bool CBoubleTrieBaseWriter::allocateMemery(int32_t needCap) 
{
int32_t allocateMemeoryNum = m_allocBcNum;
while (needCap > allocateMemeoryNum) 
allocateMemeoryNum *= m_buildIncTime;


if (allocateMemeoryNum > m_allocBcNum) 
{ //重新分配空间
SBC *tmp = new SBC[allocateMemeoryNum];
memcpy(tmp, m_pBC, m_allocBcNum*sizeof(SBC));
delete []m_pBC;
m_pBC = tmp;
m_allocBcNum = allocateMemeoryNum;
}
return true;
}


//unit必须为unit->m_childrenIndex.size()>0
/**
 * Find a base value such that base + code lands on a free slot, outside the
 * reserved first-character area, for the next character of every child key.
 *
 * Precondition: unit->m_childrenIndex is not empty.
 *
 * The slot array is grown on demand. A candidate that would make the base
 * negative is skipped and the search continues (the previous code left the
 * search loop at that point and returned the negative value).
 *
 * @param unit  the node whose children need slots.
 * @return a non-negative base value.
 */
int32_t CBoubleTrieBaseWriter::getConsistentBasePos( SInnerDATWUnit *unit)
{
    const size_t childPos = unit->m_startPos + 1;   // position of the child character in each key

    // Advance to the first slot that is still free.
    while (m_nextFreePos < m_allocBcNum && m_pBC[m_nextFreePos].m_check != -2) m_nextFreePos++;
    if (m_nextFreePos >= m_allocBcNum)
    {
        allocateMemery(m_nextFreePos+1);
    }

    const int32_t firstCode = m_keys[*(unit->m_childrenIndex.begin())][childPos];

    bool isFound = false;
    int32_t nextPos = m_nextFreePos;
    int32_t basePos = 0;
    while (!isFound)
    {
        isFound = true;
        basePos = nextPos - firstCode;
        if (basePos < 0)
        {
            // A negative base is reserved to mark keywords: move forward and retry.
            nextPos += abs(basePos);
            isFound = false;
            continue;
        }

        for (set<size_t>::iterator iter = unit->m_childrenIndex.begin(); iter != unit->m_childrenIndex.end(); iter++)
        {
            const int32_t target = basePos + m_keys[*iter][childPos];
            if (target >= m_allocBcNum)
            {
                allocateMemery(target+1);
            }

            if (m_pBC[target].m_check != -2)
            {
                // Slot taken: restart from the next free slot.
                nextPos++;
                while (nextPos < m_allocBcNum && m_pBC[nextPos].m_check != -2) nextPos++;
                if (nextPos >= m_allocBcNum)
                {
                    allocateMemery(nextPos+1);
                }
                isFound = false;
                break;
            }
            else if (target < m_reversedNum+1)
            {
                // Children must not occupy the reserved first-character slots.
                nextPos += m_reversedNum+1 - target;
                isFound = false;
                break;
            }
        }
    }

    return basePos;
}


void CBoubleTrieBaseWriter::extendUnits(size_t pos, int32_t level) 
{
size_t valueNum = m_keys[pos].size();
if(valueNum < level+1) return;
//for (size_t i=0; i<valueNum; i++) 
//{
SInnerDATWUnit *pUnit= new SInnerDATWUnit(pos, level, m_keys[pos]);
set<SInnerDATWUnit *,unitsort>::iterator iter = m_unists.find(pUnit);
if(iter == m_unists.end()) 
{ 
m_unists.insert(pUnit); //如果刚加入
if(level != valueNum-1) //如果不是最后一个
pUnit->m_childrenIndex.insert(pos);
else 
pUnit->m_endIndex.insert(pos);
}
else 
{
if (level != valueNum-1)
(*iter)->m_childrenIndex.insert(pos);
else
(*iter)->m_endIndex.insert(pos);
delete pUnit;
}
//} 
}


/** Smallest power of two that is not less than num (1 for num <= 1). */
int CBoubleTrieBaseWriter::eup2power(int num)
{
    int power = 1;
    for (; power < num; power += power)
    {
    }
    return power;
}


void CBoubleTrieBaseWriter::_print()
{
for(int32_t i=0; i<=m_usedMaxBCPos; i++) 
{
if(m_pBC[i].m_base != 0)
{
fprintf(stdout, "pos=%d, base=%d, check=%d\n", i, m_pBC[i].m_base, m_pBC[i].m_check);
}
}
}


/**
 * Split src at every occurrence of det.
 *
 * Empty fields are kept, so the result always has (number of delimiters + 1)
 * elements; an empty input yields one empty string.
 */
std::vector<std::string> split(const std::string& src, char det)
{
    std::vector<std::string> parts;
    std::string::size_type begin = 0;
    for (;;)
    {
        const std::string::size_type hit = src.find(det, begin);
        if (hit == std::string::npos)
            break;
        parts.push_back(src.substr(begin, hit - begin));
        begin = hit + 1;
    }
    parts.push_back(src.substr(begin));     // tail after the last delimiter
    return parts;
}


/**
 * Orders comma-separated rules by their number of words (fewer first) and,
 * for the same number of words, by strcmp on the whole rule.
 *
 * The number of words equals the number of commas plus one, so the commas are
 * counted directly instead of splitting both rules on every comparison.
 * Arguments are taken by const reference to avoid copying the strings.
 *
 * The former std::binary_function base was dropped: it is deprecated since
 * C++11 and removed in C++17.
 */
struct wordnumsort
{
    bool operator()(const std::string& unit1, const std::string& unit2) const
    {
        const std::string::difference_type commas1 = std::count(unit1.begin(), unit1.end(), ',');
        const std::string::difference_type commas2 = std::count(unit2.begin(), unit2.end(), ',');
        if (commas1 != commas2)
            return commas1 < commas2;
        return strcmp(unit1.c_str(), unit2.c_str()) < 0;
    }
};


/**
 * Orders strings by strcmp on their C strings.
 *
 * Arguments are taken by const reference to avoid copying on every
 * comparison. The former std::binary_function base was dropped: it is
 * deprecated since C++11 and removed in C++17.
 */
struct Stringsort
{
    bool operator()(const std::string& unit1, const std::string& unit2) const
    {
        return strcmp(unit1.c_str(), unit2.c_str()) < 0;
    }
};


/** True when needle occurs as a substring of at least one string in haystack. */
static bool anyContainsWord(const std::vector<std::string>& haystack, const std::string& needle)
{
    for (uint32_t i = 0; i < haystack.size(); i++)
    {
        if (strstr(haystack[i].c_str(), needle.c_str()) != NULL)
            return true;
    }
    return false;
}

/**
 * Check whether some rule in dec is already covered by the words in src.
 *
 * Each element of dec is a comma-separated rule. A rule is covered when every
 * one of its words is a substring of at least one string in src.
 *
 * @return true as soon as one covered rule is found, false otherwise.
 */
bool Contain(const vector<string>& src,const vector<string>& dec)
{
    for (uint32_t r = 0; r < dec.size(); ++r)
    {
        const vector<string> ruleWords = split(dec[r], ',');

        uint32_t matched = 0;
        while (matched < ruleWords.size() && anyContainsWord(src, ruleWords[matched]))
            ++matched;

        if (matched == ruleWords.size())
            return true;
    }
    return false;
}


void sortdic(const vector<string> &dictsrc, vector<string> &dictdec)
{
map<const string, char, wordnumsort> sortMap;
for(uint32_t i=0; i<dictsrc.size(); ++i)
{
sortMap.insert(make_pair(dictsrc[i],0));
}

map<string,int> decMap;
map<const string, vector<string>, Stringsort> temMap;

for (map<const string, char, wordnumsort>::iterator iter = sortMap.begin(); iter!=sortMap.end(); iter++)
{
bool bExist = false;
//判断该规则是否重复
vector<string> Words = split(iter->first, ',');
for (uint32_t j=0; j<Words.size(); ++j)
{
map<const string, vector<string>,Stringsort>::iterator it = temMap.find(Words[j]);
if (temMap.end() == it)
{
}
else
{
bExist = Contain(Words,it->second);
if (bExist)
{
break;
}
}
}


if (!bExist)
{


for (uint32_t j=0; j<Words.size(); ++j)
{
map<const string, vector<string>, Stringsort>::iterator it = temMap.find(Words[j]);
if (temMap.end() == it)
{
vector<string> vec;
vec.push_back(iter->first);
temMap.insert(make_pair(Words[j],vec));
}
else
{
it->second.push_back(iter->first);
}
}


//dictdec.push_back(iter->first);
decMap.insert(make_pair(iter->first,0));
}


}


for(uint32_t i=0; i<dictsrc.size(); ++i)//保持原词典顺序
{
if (decMap.end() != decMap.find(dictsrc[i]))
{
dictdec.push_back(dictsrc[i]);
}
}/**/
}


// ABS(a): absolute value of a; the argument is evaluated more than once.
#define ABS(a) ((a)>0?(a):-(a))
// GBC(index, base, check): load m_pBC[index].m_base / m_check into base / check.
// Expands to a reference to m_pBC, so it only works where m_pBC is in scope.
#define GBC(index, base, check)  {base = m_pBC[index].m_base; check = m_pBC[index].m_check;}

// Multi-keyword rule checker built on a double-array trie.
// Init() compiles a dictionary of comma-separated rules; Check() scans a text
// and reports whether all words of some rule occur in it.
class CMyChecker
{
public:
// 0 => use string
// 1=> use wstring
const static int USE_WSTRING = 1;
public:
// you must rewrite this functions
// (1)initing
// IN dict : dict vector
void Init(const vector<string> &dict);
// (2)checking
// IN tiezi : tiezi string
// OUT : 1=>hit dict
//       0=>miss dict
int Check(const string &tiezi);
inline int Check(const wstring &tiezi);
private:
// you can add your function here
vector<string> Split(const string& src, char det);
// ...
private:
// you can add your data-structure here
class CMyData
{
public:
string forbidenWord;
uint32_t hitCount;
};
vector<CMyData> m_myDict;
// ...

private:
SBC* m_pBC; // base/check array taken over from the trie writer in Init()
uint32_t m_nDatLen;// number of trie slots in use

uint32_t* m_pMatchPos;// per trie slot: offset of the word's record in m_pMatchInfo
uint32_t* m_pMatchInfo;// per word: record count, then (rule index, word bit) pairs

// one entry per rule
uint32_t* m_pMatch;// per rule: two 16-bit halves, the bits hit so far and the full mask

uint32_t* m_pState;// per rule: value of m_nCurState when its hit bits were last reset
uint32_t m_nCurState;// generation counter, bumped at the end of every scan
// variables needed by Check
};

// Where one word occurs: which rule it belongs to and which bit it owns there.
// (Declared as a plain struct; the former "typedef struct ... {};" named no
// alias, so the typedef keyword was ignored with a compiler warning.)
struct wordinfo
{
    uint32_t pos;   // index of the rule
    uint32_t cur;   // bit mask of the word inside that rule
};

/**
 * Compile the rule dictionary.
 *
 * Steps:
 *  1. Filter redundant rules with sortdic().
 *  2. Add every word of every remaining rule to a double-array trie and build it.
 *  3. For every distinct word, store in m_pMatchInfo the list of
 *     (rule index, word bit) pairs, and in m_pMatchPos the offset of that list,
 *     indexed by the word's trie slot.
 *  4. Run Check() once over every rule text (kept from the original code).
 *
 * Each rule's state is a pair of 16-bit values, so only the first 16 words of
 * a rule can be tracked; further words get no bit.
 *
 * If the trie cannot be built the function returns early and the trie-related
 * members stay unset.
 *
 * @param dict1  rules, each a comma-separated list of UTF-8 words.
 */
void CMyChecker::Init(const vector<string> &dict1){

    // Drop redundant rules.
    vector<string> dict;
    sortdic(dict1,dict);

    CBoubleTrieBaseWriter writer;
    m_pMatch = new uint32_t[dict.size()];
    memset(m_pMatch,0,dict.size()*sizeof(uint32_t));
    // Each 32-bit entry is used as two 16-bit halves: [hit bits, full mask].
    uint16_t* pMatch = (uint16_t*)m_pMatch;

    m_pState = new uint32_t[dict.size()];
    memset(m_pState,0,dict.size()*sizeof(uint32_t));
    m_nCurState = 0;

    // Feed every word of every rule into the trie writer.
    for(uint32_t i=0; i<dict.size(); ++i)
    {
        vector<string> multiWords = Split(dict[i], ',');
        uint32_t Count = multiWords.size();
        // Mask with one bit per word; capped at 16 bits (the old shift was undefined beyond that).
        uint16_t nLen = (Count >= 16) ? (uint16_t)0xFFFF : (uint16_t)((1u << Count) - 1);
        pMatch[i*2+1]=nLen;
        for (uint32_t j = 0; j < Count ; j++ )
        {
            writer.addKey(multiWords[j].c_str());
        }
    }

    if(!writer.buildUnits())
    {
        return;
    }

    // The writer does not free the array in its destructor, so it stays valid here.
    m_pBC = writer.m_pBC;
    fprintf(stdout, "dat successed\n");
    m_nDatLen = writer.getSlotNum();

    int nCount = 0; // number of uint32_t entries needed in m_pMatchInfo
    map<string,vector<wordinfo> > wordMap; // word -> where it occurs
    for(uint32_t i=0; i<dict.size(); ++i)
    {
        vector<string> multiWords = Split(dict[i], ',');
        uint32_t Count = multiWords.size();
        for (uint32_t j = 0; j < Count ; j++ )
        {
            wordinfo info;
            info.pos = i;
            // Words beyond the 16th cannot be represented in the 16-bit mask.
            info.cur = (j < 16) ? (1u << j) : 0;
            wordMap[multiWords[j]].push_back(info);
            nCount+=2;
        }
    }
    nCount+=wordMap.size();

    fprintf(stdout, "wordMap size:%d\n",(int)wordMap.size());

    // Build the per-word match records.
    m_pMatchPos = new uint32_t[m_nDatLen];
    uint32_t nPos = 0;
    m_pMatchInfo = new uint32_t[nCount];

    for (map< string,vector<wordinfo> >::iterator it = wordMap.begin(); it!=wordMap.end(); it++)
    {
        const int32_t unitPos = writer.getUnitPos(it->first.c_str());
        // Words the trie does not know (e.g. empty or undecodable) have no slot to attach to.
        if (unitPos < 0 || (uint32_t)unitPos >= m_nDatLen)
            continue;

        m_pMatchPos[unitPos] = nPos;
        m_pMatchInfo[nPos] = it->second.size();
        nPos++;
        for (uint32_t i=0; i<it->second.size(); ++i)
        {
            m_pMatchInfo[nPos]=((it->second)[i]).pos;
            nPos++;
            m_pMatchInfo[nPos]=((it->second)[i]).cur;
            nPos++;
        }
    }

    // Run every rule text through Check() once.
    for(uint32_t i=0; i<dict.size(); ++i)
    {
        wchar_t pTmp[1025];
        memset((char*)pTmp,0,1025*sizeof(wchar_t));
        // Convert at most 1024 characters so the last element stays a terminator.
        mbstowcs(pTmp,dict[i].c_str(),1024);
        Check(pTmp);
    }
}


// Narrow-string overload of Check.
// IN tiezi : text to check (not inspected)
// OUT : always 0 (miss) - this overload is a stub; the scanning is
//       implemented in the wstring overload.
int CMyChecker::Check(const string &tiezi){
return 0;
}


int CMyChecker::Check(const wstring &tiezi){


static uint32_t m_ncount = 1;
m_ncount++;


if (13 == m_ncount)
{
m_ncount=1;
return 0;
}/**/


register const wchar_t *ikey = tiezi.c_str();
uint32_t nLen = tiezi.length();
int32_t base,check;
int32_t preIndex;
int32_t nIndex;
uint16_t* pCheckMatch;


for (; *ikey != '\0';++ikey)
{
preIndex = *ikey;
GBC(preIndex, base, check);
for (register const wchar_t* jKey = ikey+1; *jKey != '\0'; ++jKey)
{
nIndex = base + *jKey;
GBC(nIndex, base, check);
if(check != preIndex) 
{
break;
} 


if(base < 0) //说明找到
{
uint32_t pos = m_pMatchPos[nIndex];


uint32_t nSum = m_pMatchInfo[pos];
uint32_t* pInfo = m_pMatchInfo+pos+1;
for (register uint32_t v = 0; v<nSum; v++)
{
pCheckMatch = (uint16_t*)(m_pMatch+*pInfo);
if (m_pState[*pInfo] != m_nCurState)
{
m_pState[*pInfo] = m_nCurState;
*pCheckMatch = 0;
}
*pCheckMatch|=*(pInfo+1);
if (*pCheckMatch == *(pCheckMatch+1))
{
m_nCurState++;
return 1;
}
pInfo+=2;
}
ikey = jKey+3;
break;
}
preIndex = nIndex;
} 
}/**/


m_nCurState++;


return 0;
}


/**
 * Split src at every occurrence of det, keeping empty fields.
 * The result always contains (number of delimiters + 1) strings.
 */
vector<string> CMyChecker::Split(const string& src, char det){
    vector<string> pieces;
    string::size_type from = 0;
    for (string::size_type at = src.find(det); at != string::npos; at = src.find(det, from))
    {
        pieces.push_back(src.substr(from, at - from));
        from = at + 1;
    }
    pieces.push_back(src.substr(from));     // whatever follows the last delimiter
    return pieces;
}



你可能感兴趣的:(trie算法实现,double-array,dat算法实现)