abp竞赛-之-文本文件词频查询 优化报告
HouSisong@GMail. 2007.03.15com
tag: abp,单词统计,比赛,hash,速度优化,优化报告
摘要: 以前参加过几次abp论坛的比赛 http://www.allaboutprogram.com/bb (现在的www.cpper.com/c)
其中的一个竞赛的题目是《文本文件词频查询》,本文章把自己的参赛代码的优化的思路整理出来;
很多时候优化后的版本最高达到了STL实现版本的20倍!
(2007.04.09 确认从MFC移植过来的时候引入了一个bug,“FILE* file=fopen(argv[1], "r" ); ”应该为“FILE* file=fopen(argv[1], "rb" ); ” 找了我好久:( 谢谢 )
(2007.03.17修正一个在vc2005编译器下访问vector的bug,将代码 “TNode** end=&(_vbase[_hash_power]); ” 改为 “base_t::iterator end=_vbase.end();” )
(abp现在还能够访问(只读),会员很多都“搬迁”到了 www.cpper.com/c :)
文本文件词频查询 竞赛要求:
OS:Windows
2000
/
XP
Compiler:VC6
/
VC.net
/
VC.net
2003
评判标准:正确性
+
速度
截止时间:2003年10月11日前(含)
方法:每个人可以多次提交。每次提交完了,我会告诉你你的成绩和最快的人的成绩。
内容:
一个文件,仅由大小写字母,空格和换行符组成。我们称一个词为连续的大小写字符,两边是空格或者文件头
/
尾。词大小写敏感。
某个词的词频是这个词在这个文件里面出现的次数。
要求,输入一个文件(至少有一个词,并且最大词频的词只有一个),输出那个词频最大的词。
譬如,输入:
aaa bbb
ccc ddd
aaa
输出:
aaa
补充一句:文件可能非常大。(xxxM,xG)
还有就是,文件中不会出现TAB。
(测试程序的时候,我们将vc目录中的源代码文件合成了一个数据文件来作为测试数据)
一个“标准”C++实现版本: (可以作为一个STL使用的实例:)
#pragma
warning ( disable : 4786 )
#include
<
iostream
>
#include
<
fstream
>
#include
<
string
>
#include
<
map
>
#include
<
time.h
>
using
namespace
std;
int
main(
int
argc,
char
*
argv[])
{
//
assert(argc==2);
clock_t start
=
clock();
const
char
*
file_name
=
argv[
1
];
ifstream in_file(file_name);
map
<
string
,
int
>
word_table;
string
max_word;
long
max_count
=
0
;
string
word;
while
(in_file
>>
word)
{
long old_count_inc=(++word_table[word]);
if(old_count_inc>max_count)
{
max_word=word;
max_count=old_count_inc;
}
}
cout
<<
"
Word:
"
<<
max_word
<<
"
Count:
"
<<
max_count
<<
endl;
cout
<<
"
Seconds =
"
<<
( (
double
)(clock()
-
start)
/
CLOCKS_PER_SEC )
<<
endl;
return
0
;
}
我用的测试编译器vc6.0 , CPU赛扬2.0G
下面的代码很多时候速度是上面的版本的20倍,源代码如下(优化说明在代码之后);
(我以前提交的代码使用了MFC库,为了容易编译和理解,我做了一些代码调整,去除MFC依赖,把一个复杂的代码循环展开删除了,可能慢了10%)
#pragma
warning ( disable : 4786 )
#include
<
stdio.h
>
#include
<
time.h
>
#include
<
iostream
>
#include
<
string
>
#include
<
vector
>
#include
<
algorithm
>
namespace
{
class
CMyAllot
{
enum
{ chunk_size
=
1024
*
256
};
//
块大小
char
*
_cur;
char
*
_end;
std::vector
<
char
*>
_vector;
void
*
_new_else(unsigned
int
size);
public
:
CMyAllot() :_end(
0
),_cur(
0
) { }
virtual
~
CMyAllot() {
if
(
!
_vector.empty()) DelAll(); }
inline
void
*
_fastcall New(unsigned
int
size)
{
size
=
((size
+
3
)
>>
2
<<
2
);
//
4字节边界对齐
if
((
int
)size
<
(_end
-
_cur))
//
够用
{
char
*
result
=
_cur;
_cur
+=
size;
return
result;
}
else
//
不够用
return
_new_else(size);
}
void
DelAll()
{
for
(
int
i
=
0
;i
<
(
int
)_vector.size();
++
i)
delete [] (_vector[i]);
_vector.clear();
}
};
void
*
CMyAllot::_new_else(unsigned
int
size)
{
if
(size
>
(chunk_size
>>
2
))
//
不够用,而且需要的空间较大
{
char
*
result
=
new
char
[size];
char
*
old_back
=
_vector.back();
_vector[_vector.size()
-
1
]
=
result;
_vector.push_back(old_back);
return
result;
}
else
//
不够用,开辟新的空间
{
char
*
result
=
new
char
[chunk_size];
_cur
=
result
+
size;
_end
=
result
+
chunk_size;
_vector.push_back(result);
return
result;
}
}
struct
TNode
//
hash表使用的节点类型(链表)
{
TNode
*
pNext;
unsigned
int
count;
char
str[
1
];
//
不一定只有一个字节,会根据字符串分配空间
struct
TComp
//
返回时的排序准则
{
bool
operator
()(
const
TNode
*
l,
const
TNode
*
r)
{
if
((l
->
count)
==
(r
->
count))
{
return
std::
string
(
&
l
->
str[
0
])
<
(
&
r
->
str[
0
]);
}
else
return
(l
->
count)
>
(r
->
count);
}
};
};
inline unsigned
int
_fastcall hash_value(
char
*
begin,
char
*
end)
{
unsigned
int
result
=
0
;
do
{
result
=
5
*
result
+
(
*
begin);
//
利用asm: lea reg0,[reg1*4+reg1],并且5是质数
}
while
((
++
begin)
!=
end);
return
result;
}
inline unsigned
int
_fastcall hash_value(
char
*
pstr)
{
unsigned
int
result
=
0
;
do
{ result
=
5
*
result
+
(
*
pstr); ;
//
利用asm: lea reg0,[reg1*4+reg1],并且5是质数
}
while
((
*
(
++
pstr)));
return
result;
}
//
测试字符串是否相同, 如果需要不区分大小写,修改这个函数和hash函数就可以了
inline
bool
_fastcall test_str_EQ(
char
*
begin,
char
*
end,
char
*
str)
{
//
for (;begin!=end;++begin,++str)
//
if ( (*begin)!=*(str) ) return false;
do
{
if
( (
*
begin)
!=*
(str) )
return
false
;
++
begin;
++
str;
}
while
(begin
!=
end);
return
true
;
}
}
class
CHashSet
{
typedef std::vector
<
TNode
*>
base_t;
inline unsigned
int
hash_index(
char
*
begin,
char
*
end)
const
{
return
hash_value(begin,end)
&
(_hash_mask); }
inline unsigned
int
hash_index(
char
*
pstr)
const
{
return
hash_value(pstr)
&
(_hash_mask); }
void
resize();
void
_fastcall move_insert(base_t
&
v,TNode
*
pOldNode)
const
;
TNode
*
_fastcall NewNode(
char
*
begin,
char
*
end);
void
Sort(base_t
&
v,unsigned
int
sortCount);
unsigned
int
_hash_power;
unsigned
int
_hash_mask;
unsigned
int
_node_count;
base_t _vbase;
CMyAllot _allot;
void
_fastcall else_insert(TNode
*
pNode,
char
*
begin,
char
*
end);
public
:
CHashSet();
virtual
~
CHashSet();
unsigned
int
size()
const
{
return
_node_count; }
unsigned
int
sum();
void
_fastcall insert(
char
*
begin,
char
*
end);
void
GetStrList(std::ostream
&
cout,unsigned
int
sortCount);
};
CHashSet::CHashSet()
:_hash_power(
2
),_vbase((unsigned
int
)(_hash_power),(TNode
*
)
0
)
//
注意次序
{
_node_count
=
0
;
_hash_mask
=
_hash_power
-
1
;
//
_hash_power=1<<n;
}
CHashSet::
~
CHashSet()
{
_allot.DelAll();
}
unsigned
int
CHashSet::sum()
{
unsigned
int
sum
=
0
;
if
(_node_count
>
0
)
{
base_t::iterator end
=
_vbase.end();
for
(base_t::iterator i
=
_vbase.begin();i
<
end;
++
i)
{
TNode
*
pNode
=
(
*
i);
while
(pNode
!=
0
)
{
sum
+=
pNode
->
count;
pNode
=
pNode
->
pNext;
}
}
}
return
sum;
}
void
_fastcall CHashSet::insert(
char
*
begin,
char
*
end)
{
unsigned
int
index
=
hash_index(begin,end);
TNode
*
pNode
=
_vbase[index];
if
(
!
pNode)
//
节点还没有使用
{
_vbase[index]
=
NewNode(begin,end);
++
_node_count;
}
else
{
if
(test_str_EQ(begin,end,pNode
->
str))
//
累加
++
(pNode
->
count);
else
else_insert(pNode,begin,end);
}
}
void
_fastcall CHashSet::else_insert(TNode
*
pNode,
char
*
begin,
char
*
end)
{
while
(
true
)
{
if
(
!
(pNode
->
pNext))
{
pNode
->
pNext
=
NewNode(begin,end);
++
_node_count;
if
(_node_count
>=
(_hash_power))
resize();
break
;
}
else
if
(test_str_EQ(begin,end,pNode
->
pNext
->
str))
{
++
(pNode
->
pNext
->
count);
break
;
}
pNode
=
pNode
->
pNext;
};
}
void
_fastcall CHashSet::move_insert(base_t
&
v,TNode
*
pOldNode)
const
{
TNode
*&
pNode
=
v[hash_index(pOldNode
->
str)];
pOldNode
->
pNext
=
0
;
if
(
!
pNode)
//
节点还没有使用
{
pNode
=
pOldNode;
}
else
{
if
(
!
pNode
->
pNext)
{
pNode
->
pNext
=
pOldNode;
}
else
{
TNode
*
pListNode
=
pNode
->
pNext;
while
(pListNode
->
pNext
!=
0
)
{ pListNode
=
pListNode
->
pNext; }
pListNode
->
pNext
=
pOldNode;
}
}
}
TNode
*
_fastcall CHashSet::NewNode(
char
*
begin,
char
*
end)
{
TNode
*
pNode
=
(TNode
*
)(_allot.New(
sizeof
(TNode)
+
end
-
begin));
pNode
->
pNext
=
0
;
pNode
->
count
=
1
;
char
*
i
=
pNode
->
str;
//
for (;begin!=end;++i,++begin)
//
(*i)=(*begin);
do
{
(
*
i)
=
(
*
begin);
++
i,
++
begin;
}
while
(begin
!=
end);
(
*
i)
=
char
(
0
);
return
pNode;
}
void
CHashSet::resize()
{
if
(_node_count
>=
(_hash_power))
{
base_t::iterator end
=
_vbase.end();
_hash_power
<<=
2
;
_hash_mask
=
(_hash_power)
-
1
;
base_t new_vbase(_hash_power,(TNode
*
)
0
);
for
(base_t::iterator i
=
_vbase.begin();i
!=
end;
++
i)
{
TNode
*
pNode
=
(
*
i);
while
(pNode
!=
0
)
{
TNode
*
temp
=
pNode
->
pNext;
move_insert(new_vbase,pNode);
pNode
=
temp;
}
}
_vbase.swap(new_vbase);
}
}
///
/
void
CHashSet::Sort(base_t
&
v,unsigned
int
sortCount)
{
if
(sortCount
==
1
)
{
v.resize(
1
);
base_t::iterator end
=
_vbase.end();
TNode
*
maxNode
=
_vbase[
0
];
TNode::TComp op;
for
(base_t::iterator i
=
_vbase.begin();i
!=
end;
++
i)
{
TNode
*
pNode
=
(
*
i);
while
(pNode
!=
0
)
{
if
( (maxNode
==
0
)
||
(op(pNode,maxNode)) )
maxNode
=
pNode;
pNode
=
pNode
->
pNext;
}
}
v[
0
]
=
maxNode;
}
else
{
v.resize(_node_count);
int
index
=
0
;
if
(_node_count
>
0
)
{
TNode
**
end
=&
(_vbase[_hash_power]);
for
(TNode
**
i
=&
(_vbase[
0
]);i
!=
end;
++
i)
{
TNode
*
pNode
=
(
*
i);
while
(pNode
!=
0
)
{
v[index]
=
pNode;
++
index;
pNode
=
pNode
->
pNext;
}
}
}
std::partial_sort(v.begin(),v.begin()
+
sortCount,v.end(),TNode::TComp());
}
}
void
CHashSet::GetStrList(std::ostream
&
cout,unsigned
int
sortCount)
{
if
(_node_count
>=
1
)
{
if
(sortCount
==
0
)
sortCount
=
_node_count;
else
if
(_node_count
<
sortCount)
sortCount
=
_node_count;
base_t v;
Sort(v,sortCount);
for
(
int
i
=
0
;i
<
(
int
)sortCount;
++
i)
{
std::cout
<<
"
单词:
"
<<
(
&
(v[i]
->
str[
0
]))
<<
"
计数:
"
<<
(v[i]
->
count)
<<
std::endl;
}
}
}
class
CWords
{
private
:
enum
{ cibuf_size
=
4096
};
//
缓冲区最佳大小
int
buf_size;
//
动态缓冲区大小
char
*
pBuf;
//
指向缓冲区
static
void
CreateGainTab();
//
构造“词”分析用的表
int
privateGainWord(
int
dx,
int
start_offset,
bool
isEndGain);
inline
int
GainWord(
int
dx,
int
start_offset);
//
从缓冲区获取词;
inline
void
endGainWord(
int
dx,
int
start_offset);
//
从缓冲区获取词,处理文件尾;
void
_fastcall PushWord(
char
*
begin,
char
*
end);
__int64 _CPUCount;
CHashSet _hash_set;
public
:
CWords();
virtual
~
CWords();
void
toDo(FILE
*
file);
//
循环读取文件数据到内存缓冲区
void
GetResult(std::ostream
&
cout,unsigned
int
sortCount);
};
namespace
{
static
unsigned
int
GainTab[
256
];
//
进行词法分析的表
}
//
构造“词”分析用的表
void
CWords::CreateGainTab()
{
//
static
bool
IsDo
=
false
;
if
(IsDo)
return
;
for
(
int
i
=
0
;i
<
256
;
++
i)
{
if
( ((i
>=
'
A
'
)
&&
(i
<=
'
Z
'
))
||
((i
>=
'
a
'
)
&&
(i
<=
'
z
'
))
//
|| (i=='_')
//
|| ((i>='0')&&(i<='9'))
)
GainTab[i]
=
unsigned
int
(
-
1
);
else
GainTab[i]
=
0
;
}
IsDo
=
true
;
}
CWords::CWords()
{
}
CWords::
~
CWords()
{
}
#define
asm __asm
__declspec( naked ) __int64 CPUCycleCounter()
//
获取当前CPU周期计数(CPU周期数)
{
asm
{
RDTSC
//
0F 31
//
eax,edx
ret
}
}
//
循环读取文件数据到内存缓冲区
void
CWords::toDo(FILE
*
file)
{
_CPUCount
=
::CPUCycleCounter();
std::vector
<
char
>
BufData(cibuf_size);
buf_size
=
BufData.size();
pBuf
=&
BufData[
0
];
//
get file length
fseek(file,
0
,SEEK_END);
int
file_length
=
ftell(file);
fseek(file,
0
,SEEK_SET);
int
file_pos
=
0
;
CreateGainTab();
int
dx
=
0
;
int
start_offset
=
0
;
while
(
true
)
{
if
(file_pos
+
(buf_size
-
dx)
<=
file_length)
{
fread(pBuf
+
dx,buf_size
-
dx,
1
,file);
file_pos
+=
(buf_size
-
dx);
dx
=
GainWord(dx,start_offset);
start_offset
=
0
;
if
(dx
<
0
)
//
处理超长单词
{
start_offset
=
buf_size
+
dx;
//
放大缓冲区
dx
=
buf_size;
BufData.resize(dx
*
2
);
buf_size
=
BufData.size();
pBuf
=&
BufData[
0
];
}
else
//
if ( (dx<(cibuf_size>>1)) && (buf_size>(cibuf_size<<1)) )
{
//
BufData.resize(cibuf_size);
//
减小缓冲区
//
pBuf=&BufData[0];
//
buf_size=BufData.size();
}
}
else
{
int
bordercount
=
(
int
)(file_length
-
file_pos);
if
(bordercount
>
0
)
{
fread(pBuf
+
dx,bordercount,
1
,file);
buf_size
=
dx
+
bordercount;
//
file_pos+=(bordercount);
endGainWord(dx,start_offset);
}
break
;
//
end while
}
}
_CPUCount
=
::CPUCycleCounter()
-
_CPUCount;
}
int
CWords::privateGainWord(
int
dx,
int
start_offset,
bool
isEndGain)
{
char
*
pStart
=
pBuf
+
start_offset;
char
*
pEnd
=
pBuf
+
buf_size;
int
IsInWord
=
(dx
!=
0
)
?
int
(
-
1
):
0
;
//
是否处于“词”中
char
*
i
=
pBuf
+
dx;
for
(;i
!=
pEnd;
++
i)
{
if
(IsInWord
^
GainTab[
*
(unsigned
char
*
)i])
{
if
(IsInWord)
PushWord(pStart,i);
else
pStart
=
i;
IsInWord
=
(
~
IsInWord);
}
}
/////////
dx
=
0
;
if
(IsInWord)
{
if
(isEndGain)
PushWord(pStart,pEnd);
//
最末尾的一个词
else
{
dx
=
pEnd
-
pStart;
if
(dx
>
(buf_size
>>
1
))
//
超长单词特殊处理
dx
=
(
-
dx);
//
特殊标记!
else
{
for
(
int
i
=
0
;i
<
dx;
++
i)
//
把没有处理完的单词拷贝到缓冲区开头
pBuf[i]
=
pStart[i];
}
}
}
return
dx;
}
int
CWords::GainWord(
int
dx,
int
start_offset)
{
return
privateGainWord(dx,start_offset,
false
);
}
void
CWords::endGainWord(
int
dx,
int
start_offset)
{
privateGainWord(dx,start_offset,
true
);
}
void
CWords::GetResult(std::ostream
&
cout, unsigned
int
sortCount)
{
std::cout
<<
"
无重复单词数:
"
<<
_hash_set.size()
<<
"
单词总数:
"
<<
_hash_set.sum()
<<
std::endl;
std::cout
<<
"
CPU周期计数:
"
<<
(
long
)_CPUCount
<<
std::endl;
_hash_set.GetStrList(cout,sortCount);
}
inline
void
_fastcall CWords::PushWord(
char
*
begin,
char
*
end)
{
_hash_set.insert(begin,end);
}
//////////////////////////////
/
int
CreateTxtFile(
char
*
argv[]);
int
toWork(
int
argc,
char
*
argv[]);
const
char
sParameter []
=
"
Cpt_hss filename [/N]
"
;
//
主程序
int
main(
int
argc,
char
*
argv[])
{
if
(argc
<=
1
)
{
std::cout
<<
(
"
请输入文件名称!
"
);
std::cout
<<
sParameter;
std::cout
<<
std::endl;
return
0
;
}
if
(std::
string
(argv[
1
])
==
"
/?
"
)
{
std::cout
<<
(
"
统计文件中单词出现频率。
"
);
std::cout
<<
(sParameter);
std::cout
<<
(
"
filename 指定需要进行统计的文件的名称
"
);
std::cout
<<
(
"
[/N] 显示出现频率最高的前N个单词;
"
);
std::cout
<<
(
"
如果单词出现频率相同,则按字母顺序排列;
"
);
std::cout
<<
(
"
N默认为1;
"
);
std::cout
<<
(
"
当N=0时,表示全部显示。
"
);
std::cout
<<
std::endl;
return
0
;
}
return
toWork(argc,argv);
}
int
toWork(
int
argc,
char
*
argv[])
{
clock_t start
=
clock();
FILE
*
file
=
fopen(argv[
1
],
"
rb
"
);
if
(file
==
0
)
{
std::cout
<<
(
"
打开文件时发生错误!
"
);
std::cout
<<
(sParameter);
return
0
;
}
unsigned
int
sortCount
=
1
;
if
(argc
==
3
)
sortCount
=
atoi(argv[
2
]
+
1
);
CWords words;
words.toDo(file);
fclose(file);
words.GetResult(std::cout,sortCount);
std::cout
<<
"
Seconds =
"
<<
( (
double
)(clock()
-
start)
/
CLOCKS_PER_SEC )
<<
std::endl;
return
0
;
}
重点优化说明: (这是本篇文章的重点,讲解一些基本的优化策略)
1.在读取文件方面,使用了一个自己管理的内存缓冲区来读取文件的数据;
(这样处理以后读文件占的时间约占总时间的1/7,还可以进一步优化:
进一步改进方案a:可以考虑用另一个线程异步来加载文件数据(当前处理大量文件数据的高效方案);
进一步改进方案b:如果文件不太大可以考虑使用内存映射技术来优化这一块,代码也简单很
多,而单词的表示也可以采用一个指针加一个长度(或者用头尾两个指针,或者一个指针+哨兵
位(推荐))来表示,从而避免一次深拷贝)
2.建立了一个查询表GainTab[256]用来判断一个字母是否是单词还是空白区域;
比如:可以把( ((C>='A')&&(C<='Z'))||((C>='a')&&(C<='z')) ) 简写为 ( GainTab[C]!=0 )
(其实也可以建立一个64k的表来捕捉状态,同时用两个字节来查表...)
3.把查找单词的扫描过程理解为从单词区域到空白区域的状态转换(这句可能不好理解);
比如一般常见的实现伪代码:
char
*
i
=
pBegin;
while (i!=pEnd)
{
while((i!=pEnd)&&(!GainTab[*(unsigned char*)i])) //寻找到单词开头
++i;
pStart=i;
while((i!=pEnd)&&(GainTab[*(unsigned char*)i])) //寻找该单词结束位置
++i;
if (pStart!=i)
PushWord(pStart,i);
}
我的代码:
for
(
char
*
i
=
pBegin;i
!=
pEnd;
++
i)
{
if
(IsInWord
^
GainTab[
*
(unsigned
char
*
)i])
//
捕捉所属区域状态的变化
{
if
(IsInWord)
PushWord(pStart,i);
else
pStart
=
i;
IsInWord
=
(
~
IsInWord);
}
}
该算法处理两个状态:是否在单词中、“是否在单词中”的状态是否改变;
从而消除了内部的一个循环框架,这在单词和空际较小时将带来更多好处;
(在本程序中可能所起作用不大,这里耗的时间不多,反而调用PushWord的花掉的时间更多)
(还有一个有用的见解:“经过3个字节最多能够计数一个单词”,比如利用这个观点可以建立
2字节或3字节的查询表(表的大小的取舍需要考虑CPU的缓冲区大小),同时处理更多的字节)
4.为了优化单词使用的内存,减少动态内存分配,自定义了一个CMyAllot类来管理内存的分配
5.我使用了一个自定义的hash表CHashSet(准确点应该叫做map)来储存找到的单词(hash表具有平
均常数时间的单词查找能力),表的大小会随着无重复单词数的增加而动态增长:某个HashItem
不可用时,会把新的单词加到HashItem后面,即HashItem形成一个单向list,当hash表的负债超过
某个阈值的时候,就会增大表的大小,然后所有的元素重新转移到新的表;
6.我的hash表的大小只可能为2的整数次方,所以hash值在映射到HashItem的序号时可以使用快速
的&运算(hash_value&_hash_mask); 等价于(hash_value%hash_size) , 优化掉一次求余运算
(求余和除法都是很慢的操作)
补充: 我尝试过把char字符流当作wchar_t* 流来处理,希望提高吞吐量
但为保证结果正确代码逻辑变得稍微复杂了一些,结果在我的机子上速度几乎没有变!