【AC自动机】 HDOJ 3341 Lost's revenge

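AC 自动机 + 状态压缩 DP。DP 的转移很容易想到，难点在于状态的压缩：最直接的想法是开一个四维数组，分别记录 A、C、G、T 已经用了多少个（再加上自动机结点这一维），但这样空间显然开不下。所以需要找一个 hash 函数来压缩状态。这里用变进制（混合进制）来做 hash，就像秒、分钟、小时那样：若四种字母在串中各出现 c0、c1、c2、c3 次，则第 k 位的取值为 0..ck、基数为 ck+1，总状态数为 (c0+1)(c1+1)(c2+1)(c3+1)，在串长不超过 40 时最多为 11^4 = 14641。DP 过程中要注意区分可达与不可达的状态（用 -1 表示不可达），并且某个字母取完之后就不能再取。另一种写法是把变进制每一位的基数改为 ck+2：这样即使从已取完的字母继续转移，也只会落到多出来的第 ck+1 个取值上，不会进位到其他位；这些状态既不会被当作转移起点，也不参与统计答案，于是就不用再判断字母是否取完。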

#include <iostream>
#include <sstream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <climits>
// Capacity of the trie arrays (one slot per node) and of the global buffers
// below; presumably sized for at most 50 patterns of length <= 10 plus the
// root (HDU 3341 limits) — verify.
#define maxn 505
// The following macros and the LL typedef are contest-template boilerplate;
// none of them is used by this program.
#define eps 1e-6
#define mod 10007
#define INF 99999999
#define lowbit(x) ((x)&(-(x)))
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
typedef long long LL;
// NOTE: this directive also makes std::hash visible to unqualified lookup, so
// a bare reference to the global array `hash` declared below is ambiguous under
// C++11 and later; `::hash` names the array unambiguously.
using namespace std;

// Aho-Corasick automaton over the 4-letter alphabet A, C, G, T.
// Usage: init(); then, for each pattern, copy it into s and call insert();
// finally call build(). After build(), next[][] is a complete transition table
// and end[v] counts every pattern that ends at v or along v's failure chain.
struct trie
{
	int next[maxn][4];  // next[v][c]: transition from node v on letter code c
	int fail[maxn];     // failure link; -1 until build() sets it (root keeps -1)
	int end[maxn];      // number of patterns recognised on reaching this node
	char s[maxn];       // input buffer read by insert()
	std::queue<int> q;  // BFS work queue used by build()
	int hash[26];       // letter -> code: A=0, C=1, G=2, T=3, other letters -> 0
	int top, now, root; // node count, cursor node, root index

	// Creates a node with no children and no failure link; returns its index.
	int newnode(void)
	{
		const int id = top;
		++top;
		end[id] = 0;
		fail[id] = -1;
		for (int c = 3; c >= 0; --c)
			next[id][c] = -1;
		return id;
	}

	// Resets the automaton to a lone root and installs the letter coding.
	void init(void)
	{
		top = 0;
		root = newnode();
		std::memset(hash, 0, sizeof(hash));
		const char letters[] = "ACGT";
		for (int code = 0; code < 4; ++code)
			hash[letters[code] - 'A'] = code;
	}

	// Adds the pattern currently held in s to the trie.
	void insert(void)
	{
		now = root;
		for (const char *p = s; *p != '\0'; ++p) {
			const int code = hash[*p - 'A'];
			if (next[now][code] < 0)
				next[now][code] = newnode();
			now = next[now][code];
		}
		++end[now];
	}

	// Breadth-first pass: sets failure links, fills missing transitions from
	// the failure target, and accumulates end[] along the failure chain.
	void build(void)
	{
		now = root;
		for (int c = 0; c < 4; ++c) {
			int &child = next[root][c];
			if (child == -1) {
				child = root; // missing edge at the root loops back to it
				continue;
			}
			fail[child] = root;
			q.push(child);
		}
		while (!q.empty()) {
			now = q.front();
			q.pop();
			const int link = fail[now];
			// BFS order guarantees end[link] is already final.
			end[now] += end[link];
			for (int c = 0; c < 4; ++c) {
				int &child = next[now][c];
				if (child == -1) {
					child = next[link][c];
				} else {
					fail[child] = next[link][c];
					q.push(child);
				}
			}
		}
	}
}tmp;
// dp[state][node]: best number of pattern occurrences for an arrangement whose
// letter counts are encoded by `state` (see HASH) and which ends in automaton
// node `node`; -1 marks an unreachable pair. With radix (count+1) per letter
// there are at most 11^4 = 14641 states, assuming the DNA string has at most
// 40 letters (HDU 3341 limit — verify), so 21000 rows suffice.
int dp[21000][maxn];
// hash[k]: occurrences of letter k (A=0, C=1, G=2, T=3) in the DNA string;
// only indices 0..3 are used. Under `using namespace std;` a bare `hash` is
// ambiguous with std::hash in C++11+, so refer to it as `::hash`.
int hash[maxn];
// h[k]: place value of letter k's digit in the mixed-radix state code; 0..3 used.
int h[maxn];
// The DNA string of the current test case.
char s[maxn];
// Pattern count as read by main(); read() later overwrites it with strlen(s).
int n;

void init(void)
{
	memset(dp, -1, sizeof dp);
	memset(hash, 0, sizeof hash);
	memset(h, 0, sizeof h);
}
void read(void)
{
	int i;
	tmp.init();
	for(i = 0; i < n; i++) {
		scanf("%s", tmp.s);
		tmp.insert();
	}
	tmp.build();
	scanf("%s", s);
	n = strlen(s);
	for(i = 0; i < n; i++)
		hash[tmp.hash[s[i] - 'A']]++;
	/*
	h[3] = 1;
	h[2] = hash[3]+1;
	h[1] = (hash[3]+1)*(hash[2]+1);
	h[0] = (hash[3]+1)*(hash[2]+1)*(hash[1]+1);
	*/
	h[0] = 1;
	h[1] = hash[0]+1;
	h[2] = (hash[0]+1)*(hash[1]+1);
	h[3] = (hash[0]+1)*(hash[1]+1)*(hash[2]+1);
}
// Packs (a, b, c, d), the numbers of A, C, G and T consumed so far, into one
// mixed-radix index using the place values h[0..3] prepared by read().
inline int HASH(int a, int b, int c, int d)
{
	int code = d * h[3];
	code += c * h[2];
	code += b * h[1];
	code += a * h[0];
	return code;
}
void work(void)
{
	int i1, i2, i3, i4, i, j, k, ans, temp;
	dp[0][0] = 0;
	for(i1 = 0; i1 <= hash[0]; i1++)
		for(i2 = 0; i2 <= hash[1]; i2++)
			for(i3 = 0; i3 <= hash[2]; i3++)
				for(i4 = 0; i4 <= hash[3]; i4++){
					temp = HASH(i1, i2, i3, i4);
					for(j = 0; j < tmp.top; j++)
						if(~dp[temp][j])
							for(k = 0; k < 4; k++) {
								if(k == 0 && i1 == hash[0]) continue;
								if(k == 1 && i2 == hash[1]) continue;
								if(k == 2 && i3 == hash[2]) continue;
								if(k == 3 && i4 == hash[3]) continue;
								dp[temp+h[k]][tmp.next[j][k]] = max(dp[temp+h[k]][tmp.next[j][k]], dp[temp][j] + tmp.end[tmp.next[j][k]]);
							}
				}
	ans = 0;
	temp = HASH(hash[0], hash[1], hash[2], hash[3]);
	for(i = 0; i < tmp.top; i++)
		ans = max(ans, dp[temp][i]);
	printf("%d\n", ans);
}
// Processes test cases until a case with n == 0 or until end of input.
// Each case prints "Case k: " followed by the answer from work().
// The scanf result is checked explicitly. Otherwise, at EOF without a
// terminating 0, n keeps the non-zero value read() left in it (strlen of the
// DNA string), and the loop would spin forever.
int main(void)
{
	int cas = 0;
	while (scanf("%d", &n) == 1 && n != 0) {
		init();
		read();
		printf("Case %d: ", ++cas);
		work();
	}
	return 0;
}


变进制每位基数取 ck+2 的写法（无需判断字母是否取完）

#include <iostream>
#include <sstream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <climits>
// Capacity of the trie arrays (one slot per node) and of the global buffers
// below; presumably sized for at most 50 patterns of length <= 10 plus the
// root (HDU 3341 limits) — verify.
#define maxn 505
// The following macros and the LL typedef are contest-template boilerplate;
// none of them is used by this program.
#define eps 1e-6
#define mod 10007
#define INF 99999999
#define lowbit(x) ((x)&(-(x)))
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
typedef long long LL;
// NOTE: this directive also makes std::hash visible to unqualified lookup, so
// a bare reference to the global array `hash` declared below is ambiguous under
// C++11 and later; `::hash` names the array unambiguously.
using namespace std;

// Aho-Corasick automaton over the 4-letter alphabet A, C, G, T.
// Usage: init(); then, for each pattern, copy it into s and call insert();
// finally call build(). After build(), next[][] is a complete transition table
// and end[v] counts every pattern that ends at v or along v's failure chain.
struct trie
{
	int next[maxn][4];  // next[v][c]: transition from node v on letter code c
	int fail[maxn];     // failure link; -1 until build() sets it (root keeps -1)
	int end[maxn];      // number of patterns recognised on reaching this node
	char s[maxn];       // input buffer read by insert()
	std::queue<int> q;  // BFS work queue used by build()
	int hash[26];       // letter -> code: A=0, C=1, G=2, T=3, other letters -> 0
	int top, now, root; // node count, cursor node, root index

	// Creates a node with no children and no failure link; returns its index.
	int newnode(void)
	{
		const int id = top;
		++top;
		end[id] = 0;
		fail[id] = -1;
		for (int c = 3; c >= 0; --c)
			next[id][c] = -1;
		return id;
	}

	// Resets the automaton to a lone root and installs the letter coding.
	void init(void)
	{
		top = 0;
		root = newnode();
		std::memset(hash, 0, sizeof(hash));
		const char letters[] = "ACGT";
		for (int code = 0; code < 4; ++code)
			hash[letters[code] - 'A'] = code;
	}

	// Adds the pattern currently held in s to the trie.
	void insert(void)
	{
		now = root;
		for (const char *p = s; *p != '\0'; ++p) {
			const int code = hash[*p - 'A'];
			if (next[now][code] < 0)
				next[now][code] = newnode();
			now = next[now][code];
		}
		++end[now];
	}

	// Breadth-first pass: sets failure links, fills missing transitions from
	// the failure target, and accumulates end[] along the failure chain.
	void build(void)
	{
		now = root;
		for (int c = 0; c < 4; ++c) {
			int &child = next[root][c];
			if (child == -1) {
				child = root; // missing edge at the root loops back to it
				continue;
			}
			fail[child] = root;
			q.push(child);
		}
		while (!q.empty()) {
			now = q.front();
			q.pop();
			const int link = fail[now];
			// BFS order guarantees end[link] is already final.
			end[now] += end[link];
			for (int c = 0; c < 4; ++c) {
				int &child = next[now][c];
				if (child == -1) {
					child = next[link][c];
				} else {
					fail[child] = next[link][c];
					q.push(child);
				}
			}
		}
	}
}tmp;
// dp[state][node]: best number of pattern occurrences for an arrangement whose
// letter counts are encoded by `state` (see HASH) and which ends in automaton
// node `node`; -1 marks an unreachable pair. With radix (count+2) per letter
// the largest index is 12^4 - 1 = 20735, assuming the DNA string has at most
// 40 letters (HDU 3341 limit — verify), so 21000 rows suffice.
int dp[21000][maxn];
// hash[k]: occurrences of letter k (A=0, C=1, G=2, T=3) in the DNA string;
// only indices 0..3 are used. Under `using namespace std;` a bare `hash` is
// ambiguous with std::hash in C++11+, so refer to it as `::hash`.
int hash[maxn];
// h[k]: place value of letter k's digit in the mixed-radix state code; 0..3 used.
int h[maxn];
// The DNA string of the current test case.
char s[maxn];
// Pattern count as read by main(); read() later overwrites it with strlen(s).
int n;

void init(void)
{
	memset(dp, -1, sizeof dp);
	memset(hash, 0, sizeof hash);
	memset(h, 0, sizeof h);
}
void read(void)
{
	int i;
	tmp.init();
	for(i = 0; i < n; i++) {
		scanf("%s", tmp.s);
		tmp.insert();
	}
	tmp.build();
	scanf("%s", s);
	n = strlen(s);
	for(i = 0; i < n; i++)
		hash[tmp.hash[s[i] - 'A']]++;
	/*
	h[3] = 1;
	h[2] = hash[3]+1;
	h[1] = (hash[3]+1)*(hash[2]+1);
	h[0] = (hash[3]+1)*(hash[2]+1)*(hash[1]+1);
	*/
	h[0] = 1;
	h[1] = hash[0]+2;
	h[2] = (hash[0]+2)*(hash[1]+2);
	h[3] = (hash[0]+2)*(hash[1]+2)*(hash[2]+2);
}
// Packs (a, b, c, d), the numbers of A, C, G and T consumed so far, into one
// mixed-radix index using the place values h[0..3] prepared by read().
inline int HASH(int a, int b, int c, int d)
{
	int code = d * h[3];
	code += c * h[2];
	code += b * h[1];
	code += a * h[0];
	return code;
}
void work(void)
{
	int i1, i2, i3, i4, i, j, k, ans, temp;
	dp[0][0] = 0;
	for(i1 = 0; i1 <= hash[0]; i1++)
		for(i2 = 0; i2 <= hash[1]; i2++)
			for(i3 = 0; i3 <= hash[2]; i3++)
				for(i4 = 0; i4 <= hash[3]; i4++){
					temp = HASH(i1, i2, i3, i4);
					for(j = 0; j < tmp.top; j++)
						if(~dp[temp][j])
							for(k = 0; k < 4; k++)
								dp[temp+h[k]][tmp.next[j][k]] = max(dp[temp+h[k]][tmp.next[j][k]], dp[temp][j] + tmp.end[tmp.next[j][k]]);
				}
	ans = 0;
	temp = HASH(hash[0], hash[1], hash[2], hash[3]);
	for(i = 0; i < tmp.top; i++)
		ans = max(ans, dp[temp][i]);
	printf("%d\n", ans);
}
// Processes test cases until a case with n == 0 or until end of input.
// Each case prints "Case k: " followed by the answer from work().
// The scanf result is checked explicitly. Otherwise, at EOF without a
// terminating 0, n keeps the non-zero value read() left in it (strlen of the
// DNA string), and the loop would spin forever.
int main(void)
{
	int cas = 0;
	while (scanf("%d", &n) == 1 && n != 0) {
		init();
		read();
		printf("Case %d: ", ++cas);
		work();
	}
	return 0;
}


你可能感兴趣的:(dp,HDU,AC自动机)