hdu 4117 GRE Words (AC自动机+线段树)
题意:给出n个字符串,每个字符串有一个权值,我们从中拿出若干个来,这若干个字符串,前一个必须是后一个的子串,问,我们能拿出的这若干个串的权值和最大是多少。
解题思路:AC自动机。字符串匹配算法,大概就是kmp,ac自动机,后缀数组,后缀自动机这么几种了。对于这题,我们很容易想到暴力dp,用kmp去匹配,总复杂度可以做到o(n^2+2*m)(n为字符串个数,m为所有字符串的总长),但这样还不够,超时妥妥的。那就要考虑怎么维护这个dp了。把自动机建好后,从前往后一个个的,去匹配。在自动机上匹配的过程中,我们可以发现,这个串走到的所有的节点,以及这些节点一路往根fail的节点,都是这个串的子串,那么以这个串结尾的最大值,就是这所有节点的最大值,加上当前枚举的串的权值了。然而,我们如果在匹配的过程中,一直fail走,那么,超时也是妥妥的,m^2的数据还是能构造出来的。因此,我们要考虑的是怎么维护这个最大值了。对于fail指针,理解深刻的话,应该很容易想到,我们根据fail指针连边,会构成一颗fail树,那么对于当前串走到的某一个节点,他的最大值,就是fail树上,从他到根的这条链上的最大值了。那么当我们得到以某一个字符串为结尾的序列,能得到的最大值时,我们能影响到哪些节点的最值呢?很显然,fail树上,该节点的子树都会受到影响,那么就好办了。一颗子树,根据dfs序,是一段连续的区间,根据这个,我们就可以构造线段树了。区间更新,单点询问最大值就ok了。
另外,这题要处理的细节还是不少的。。。
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<stdio.h> #include<string.h> #include<algorithm> #include<queue> #define lson l , m , rt << 1 #define rson m + 1 , r , rt << 1 | 1 using namespace std ; const int maxn = 333333 ; const int INF = -111111111 ; int li[maxn] , num[maxn] ; int le[maxn] , ri[maxn] ; int col[maxn<<2] , mx[maxn<<2] , to[maxn] ; int dp[maxn] ; struct Edge { int t , next ; } edge[maxn<<1] ; int head[maxn] , T , cnt ; void new_edge ( int a , int b ) { edge[T].t = b ; edge[T].next = head[a] ; head[a] = T ++ ; } void push_up ( int rt ) { mx[rt] = max ( mx[rt<<1] , mx[rt<<1|1] ) ; } void push_down ( int rt ) { if ( col[rt] ) { col[rt<<1] = max ( col[rt<<1] , col[rt] ) ; col[rt<<1|1] = max ( col[rt<<1|1] , col[rt] ) ; mx[rt<<1] = max ( mx[rt<<1] , col[rt] ) ; mx[rt<<1|1] = max ( mx[rt<<1|1] , col[rt] ) ; col[rt] = 0 ; } } void build ( int l , int r , int rt ) { col[rt] = mx[rt] = 0 ; if ( l == r ) return ; int m = ( l + r ) >> 1 ; build ( lson ) ; build ( rson ) ; push_up ( rt ) ; } int query ( int a , int l , int r , int rt ) { if ( l == r ) return mx[rt] ; push_down ( rt ) ; int m = ( l + r ) >> 1 ; int ret = 0 ; if ( a <= m ) ret = query ( a , lson ) ; else ret = query ( a , rson ) ; push_up ( rt ) ; return ret ; } void update ( int a , int b , int c , int l , int r , int rt ) { if ( a <= l && r <= b ) { col[rt] = max ( c , col[rt] ) ; mx[rt] = max ( c , mx[rt] ) ; return ; } push_down ( rt ) ; int m = ( l + r ) >> 1 ; if ( a <= m ) update ( a , b , c , lson ) ; if ( m < b ) update ( a , b , c , rson ) ; push_up ( rt ) ; } struct ac_auto { queue<int> Q ; int tot ; int c[26][maxn] , fail[maxn] , val[maxn] ; inline int new_node () { int i ; for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ; fail[tot] = val[tot] = 0 ; return tot ++ ; } void insert ( char *s , int id ) { int len = strlen ( s ) ; int i , now = 0 ; for ( i = 0 ; i < len ; i ++ ) { int k = s[i] - 'a' ; if ( !c[k][now] ) c[k][now] = new_node () ; now = c[k][now] ; } to[id] = now ; } void get_fail () { int u = 0 , i , e , j ; for ( i = 0 ; i < 26 ; i ++ ) if ( c[i][u] ) Q.push ( c[i][u] ) ; while ( !Q.empty () ) { u = Q.front () ; Q.pop () ; for ( i = 0 ; i < 26 ; i ++ ) { if ( c[i][u] ) { e = c[i][u] ; j = fail[u] ; fail[e] = c[i][j] ; Q.push ( e ) ; } else c[i][u] = c[i][fail[u]] ; } } } void init () { tot = 0 ; new_node () ; } void dfs ( int u , int fa ) { int i ; le[u] = ++ cnt ; for ( i = head[u] ; i != -1 ; i = edge[i].next ) { int v = edge[i].t ; if ( v == fa ) continue ; dfs ( v , u ) ; } ri[u] = cnt ; } void fuck () { int i ; cnt = 0 ; memset ( head , -1 , sizeof ( head ) ) ; for ( i = 1 ; i < tot ; i ++ ) { new_edge ( fail[i] , i ) ; new_edge ( i , fail[i] ) ; } dfs ( 0 , -1 ) ; } void work ( char *s , int t ) { int i , now = 0 ; int ans = 0 , mx = 0 ; build ( 1 , cnt , 1 ) ; for ( i = 0 ; i < t ; i ++ ) { int k = s[i] - 'a' ; now = c[k][now] ; mx = max ( mx , query ( le[now] , 1 , cnt , 1 ) ) ; if ( li[i] ) { int add = num[li[i]] > 0 ? num[li[i]] : 0 ; ans = max ( ans , mx + add ) ; update ( le[now] , ri[now] , mx + add , 1 , cnt , 1 ) ; now = mx = 0 ; } } printf ( "%d\n" , ans ) ; } } ac ; char s[maxn] ; char s1[maxn] ; int main () { int cas , ca = 0 , i ; scanf ( "%d" , &cas ) ; while ( cas -- ) { int n , j ; scanf ( "%d" , &n ) ; ac.init () ; int t = 0 ; T = 0 ; memset ( li , 0 , sizeof ( li ) ) ; for ( i = 1 ; i <= n ; i ++ ) { scanf ( "%s%d" , s , &num[i] ) ; ac.insert ( s , i ) ; int len = strlen ( s ) ; for ( j = 0 ; j < len ; j ++ ) s1[t++] = s[j] ; li[t-1] = i ; } ac.get_fail () ; ac.fuck () ; printf ( "Case #%d: " , ++ ca ) ; ac.work ( s1 , t ) ; } return 0 ; }