题目链接:Click here~~
题意:
给 n 个字符串的集合,有些字符串是要必须在主串出现,有些必须不在主串出现,其他串有分数加成。要求删除最少字符满足要求且使分数最大。
解题思路:
此题有两个优化目标。
令 dp1[i][j][mask] 表示考虑前 i 个字符,跑到节点 j,选取必须字符串的集合为 mask 的最少删除字符数。令 dp2[i][j][mask] 表示在 dp1 相应状态下,最多能获取的分数。
跑了900+ms,不算快。有种做法是只记录一种状态,记录删除相应个数的字符得到的分数,能跑200+ms。
#include <queue> #include <stdio.h> #include <string.h> #include <algorithm> using namespace std; #define CLR(a,v) memset(a,v,sizeof(a)) namespace Trie { const int N = 16 * 100; const int Size = 26; int top,m; struct Node{ Node *next[Size], *f; int ended_val; int ended_mask; bool ended_bad; }node[N], *root; inline Node* new_node() { node[top].ended_val = 0; node[top].ended_mask = 0; node[top].ended_bad = false; CLR(node[top].next,NULL); return &node[top++]; } void init() { top = m = 0; root = new_node(); } void insert(char *s,int val) { Node *u = root; for(int i=0;s[i];i++) { int id = s[i] - 'a'; if(u->next[id] == NULL) u->next[id] = new_node(); u = u->next[id]; } if(val == 999) u->ended_mask = 1 << (m++); else if(val == -999) u->ended_bad = true; else u->ended_val = val; } } namespace ACam { using namespace Trie; void get_fail() { queue<Node*> Q; for(int i=0;i<Size;i++) { Node *&ch = root->next[i]; if(!ch) ch = root; else { ch->f = root; Q.push(ch); } } while(!Q.empty()) { Node *cur = Q.front();Q.pop(); for(int i=0;i<Size;i++) { Node *&ch = cur->next[i]; if(!ch) ch = cur->f->next[i]; else { ch->f = cur->f->next[i]; ch->ended_val += ch->f->ended_val; ch->ended_bad |= ch->f->ended_bad; ch->ended_mask |= ch->f->ended_mask; Q.push(ch); } } } } int dp1[2][N][1<<8]; int dp2[2][N][1<<8]; inline bool better(int cur,int j,int k,int nxt,int jj,int kk,bool del,int add) { return dp1[cur][j][k]+del < dp1[nxt][jj][kk] || dp1[cur][j][k]+del == dp1[nxt][jj][kk] && dp2[cur][j][k]+add > dp2[nxt][jj][kk]; } void solve(char *s) { CLR(dp1[0],63); CLR(dp2[0],0); const int inf = dp1[0][0][0]; dp1[0][0][0] = 0; int cur = 0 , nxt = 1; pair<int,int> ans = make_pair(inf,0); for(int i=0;s[i];i++) { CLR(dp1[nxt],63); CLR(dp2[nxt],0); for(int j=0;j<top;j++) { for(int mask=0;mask<(1<<m);mask++) { if(dp1[cur][j][mask] == inf) continue; if(better(cur,j,mask,nxt,j,mask,true,0)) { dp1[nxt][j][mask] = dp1[cur][j][mask] + 1; dp2[nxt][j][mask] = dp2[cur][j][mask]; } int k = s[i] - 'a'; int jj = node[j].next[k] - node; if(node[jj].ended_bad) continue; int __mask = mask | node[jj].ended_mask; if(better(cur,j,mask,nxt,jj,__mask,false,node[jj].ended_val)) { dp1[nxt][jj][__mask] = dp1[cur][j][mask]; dp2[nxt][jj][__mask] = dp2[cur][j][mask] + node[jj].ended_val; } } } cur ^= 1 , nxt ^= 1; } int full_mask = (1<<m) - 1; for(int j=0;j<top;j++) if(dp1[cur][j][full_mask] < ans.first || dp1[cur][j][full_mask] == ans.first && dp2[cur][j][full_mask] > ans.second) ans = make_pair(dp1[cur][j][full_mask],dp2[cur][j][full_mask]); if(ans.first == inf) puts("Banned"); else printf("%d %d\n",ans.first,ans.second); } } char str[105]; int main() { int T,n, ncase = 0; scanf("%d",&T); while(T--) { ACam::init(); scanf("%d",&n); while(n--) { int val; scanf("%s%d",str,&val); ACam::insert(str,val); } ACam::get_fail(); scanf("%s",str); printf("Case %d: ",++ncase); ACam::solve(str); } return 0; }