题目链接
给 k k k个字符串 t 1 , t 2 , . . . t k t_1,t_2,...t_k t1,t2,...tk, t i t_i ti有权值 c i c_i ci.令 F ( T , t ) F(T,t) F(T,t)表示字符串 T T T中包含多少个 t t t, G ( T ) = ∑ i = 1 k F ( T , t i ) ∗ c i G(T)=\sum_{i=1}^kF(T,t_i)*c_i G(T)=∑i=1kF(T,ti)∗ci。
现在给出一个字符串 S S S, S S S中有最多14个位置是未知的,你可以在这些位置上填互不相同的字母 a − n a-n a−n,求 G ( S ) G(S) G(S)最大可以是多少。
∑ ∣ t i ∣ ≤ 1000 , ∣ S ∣ ≤ 5 e 4 , − 1 0 6 ≤ c i ≤ 1 0 6 \sum |t_i|\le 1000, |S|\le5e4,-10^6\le c_i \le 10^6 ∑∣ti∣≤1000,∣S∣≤5e4,−106≤ci≤106
注意到未知的位置较少,且必须要填互不相同的字母,这提示我们用状压DP去写。
而统计一些模板字符在一个字符串里面出现的次数和贡献,可以使用ac自动机求出。在这题中的障碍是那些未知的位置。
注意到 ∑ ∣ t i ∣ ≤ 1000 \sum|t_i|\le 1000 ∑∣ti∣≤1000,AC自动机最多有1000个结点。未知位置最多有14个,所以原本的串 S S S最多被分成15段已知的固定的串。
我们令 n x t [ u ] [ i ] nxt[u][i] nxt[u][i]表示ac自动机的结点 u u u跑一遍 S S S的第 i i i段串之后变成了结点 n x t [ u ] [ i ] nxt[u][i] nxt[u][i]。令 s u m [ u ] [ i ] sum[u][i] sum[u][i]表示这个过程中得到的贡献。
我们用 d p [ u ] [ m a s k ] , ( 假 设 m a s k 中 的 1 的 个 数 为 c n t ) dp[u][mask],(假设mask中的1的个数为cnt) dp[u][mask],(假设mask中的1的个数为cnt)表示:
处理完前 c n t cnt cnt个未知位置,使用的字符集合为 m a s k mask mask,当前位置为第 c n t + 1 cnt+1 cnt+1段的最后一个字母,在ac自动机上的位置为结点 u u u的情况下,得到的G的最大值.
它的转移如图表示:
先枚举当前使用的字符集合mask,然后枚举上一段的结尾走到了ac自动机的u,根据第cnt个位置填什么字符来转移:
转移的时候有三段贡献:
dp[ nxt[ch[u][i]][num] ][mask] =max(dp[ nxt[ch[u][i]][num] ][mask], dp[u][mask^(1<<i)]+cost[ch[u][i]]+sum[ch[u][i]][num]);
ac代码:
#include
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define fors(i, a, b) for(int i = (a); i < (b); ++i)
using namespace std;
const int maxn = 4e5 + 50;
int ch[maxn][15], fail[maxn];
ll cost[maxn], rt, tot = 0;
void ins(char *s, int val){
int p = rt;
while(*s){
int x = *s - 'a';
if(!ch[p][x]) {
ch[p][x] = ++tot;
}
p = ch[p][x];
s++;
}
cost[p] += val;
}
queue<int> q;
void get_fail()
{
while(q.size()) q.pop();
for(int i = 0; i < 15; ++i)
if(ch[rt][i]) q.push(ch[rt][i]), fail[ch[rt][i]] = rt;
else ch[rt][i] = rt;
while(q.size()){
int cur = q.front(); q.pop();
for(int i = 0; i < 15; ++i){
if(ch[cur][i]) {
fail[ ch[cur][i] ] = ch[ fail[cur] ][i];
q.push(ch[cur][i]);
cost[ch[cur][i]] += cost[ fail[ ch[cur][i] ] ];
}
else ch[cur][i] = ch[fail[cur]][i];
}
}
}
char t[1050];
void init(){
tot = 0; rt = ++tot;
int n; scanf("%d", &n);
fors(i, 0, n){
int x;
scanf("%s%d", t, &x); ins(t, x);
}
get_fail();
}
char s[maxn];
int pos[20], cnt = 0;
int nxt[1050][17];
ll sum[1050][17];
ll dp[1050][1<<14];
int cal(int x){int res = 0; while(x) res++, x-=lowbit(x); return res;}
void sol(){
scanf("%s", s);
int n = strlen(s);
pos[cnt++] = -1;
fors(i, 0, n) if(s[i] == '?') pos[cnt++] = i;
pos[cnt] = n;
fors(i, 0, cnt){
fors(u, 1, tot+1){
int p = u;
fors(j, pos[i]+1, pos[i+1]){
p = ch[p][s[j]-'a'];
sum[u][i] += cost[p];
}nxt[u][i] = p;
}
}
memset(dp, 0xcf, sizeof dp);
dp[nxt[rt][0]][0] = sum[rt][0];
ll ans = -1e18;
if(cnt == 1) ans = sum[rt][0];//if there is no "?"
fors(mask, 1, (1<<14)){
int num = cal(mask);
if(num > cnt-1) continue;
fors(u, 1, tot+1){
fors(i, 0, 14){
if(mask>>i&1){
dp[ nxt[ch[u][i]][num] ][mask] =
max(dp[ nxt[ch[u][i]][num] ][mask], dp[u][mask^(1<<i)]+cost[ch[u][i]]+sum[ch[u][i]][num]);
if(num == cnt-1) {
ans = max(ans, dp[ nxt[ch[u][i]][num] ][mask]);
}
}
}
}
}
cout<<ans<<endl;
}
int main()
{
init();
sol();
}