BCD Code zoj3494

  都是泪,看来以后要多做做数位DP题了。看了《浅谈数位类统计问题》后,感觉把数位DP转化成树的思想真心不错:它能降低思维的难度,可以很方便地处理一些细节问题(主要是对前导0的处理);用记忆化搜索实现起来比较方便,而且效率也比较高。


记忆化搜索:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <queue>
#include <algorithm>
#include <vector>
#include <cstring>
#include <stack>
#include <cctype>
#include <utility>   
#include <map>
#include <string>  
#include <climits> 
#include <set>
#include <string>    
#include <sstream>
#include <utility>   
#include <ctime>

using std::priority_queue;
using std::vector;
using std::swap;
using std::stack;
using std::sort;
using std::max;
using std::min;
using std::pair;
using std::map;
using std::string;
using std::cin;
using std::cout;
using std::set;
using std::queue;
using std::string;
using std::istringstream;
using std::make_pair;
using std::getline;
using std::greater;
using std::endl;
using std::multimap;
using std::deque;

// Shorthand type names carried over from the author's contest template.
// Apart from ULL (the type of LIM), none of them is referenced by this program.
typedef long long LL;
typedef unsigned long long ULL;
typedef std::pair<int, int> PAIR;
typedef std::multimap<int, int> MMAP;

// MAXN sizes the trie (AC::ch / f / val / que) and the rows of the dfs memo
// table; SIGMA_SIZE is the trie alphabet size (the characters '0' and '1').
const int MAXN = 2010;
const int SIGMA_SIZE = 2;
// Template constants that this program does not use.
const int MAXM = 110;
const int MAXE = 300010;
const int MAXH = 18;
const int INFI = (INT_MAX - 1) >> 1;
// All counts are accumulated and printed modulo MOD.
const int MOD = 1000000009;
const ULL LIM = 1000000000000000ull;

// Aho-Corasick automaton over the binary alphabet {'0', '1'} (SIGMA_SIZE == 2).
// insert() builds a trie of the forbidden patterns; construct() then adds
// failure links and fills in every missing edge, so afterwards ch[u][c] is
// defined for every node u and bit c, and val[u] is true exactly when some
// inserted pattern is a suffix of the root-to-u bit string.
struct AC
{
	int ch[MAXN][SIGMA_SIZE];  // child edges; 0 doubles as "absent" before construct()
	int f[MAXN];               // failure link of each node
	bool val[MAXN];            // true if entering this node completes a forbidden pattern
	int size;                  // number of allocated nodes; node 0 is the root

	// Resets the automaton to a single empty root node.
	void init()
	{
		memset(ch[0], 0, sizeof(ch[0]));
		f[0] = 0;
		val[0] = false;
		size = 1;
	}

	// Maps a pattern character ('0' or '1') to its edge index.
	inline int idx(char temp)
	{
		return temp-'0';
	}

	// Adds the NUL-terminated pattern S (characters '0'/'1') to the trie and
	// marks its final node as forbidden. Takes const char* because the string
	// is only read; existing char* callers convert implicitly.
	// NOTE(review): there is no check against MAXN; this assumes the total
	// pattern length fits the node pool.
	void insert(const char *S)
	{
		int u = 0;
		for(; *S; ++S)
		{
			int id = idx(*S);
			if(!ch[u][id])
			{
				// Allocate a fresh node with no outgoing edges.
				memset(ch[size], 0, sizeof(ch[size]));
				val[size] = false;
				ch[u][id] = size++;
			}
			u = ch[u][id];
		}
		val[u] = true;
	}

	int que[MAXN];   // BFS queue used by construct()
	int front, back;

	// Computes failure links breadth-first and completes the transition table.
	// f[u] is strictly shallower than u, so BFS has already finalized
	// val[f[u]] and the edges of f[cur] when they are read here.
	void construct()
	{
		front = back = 0;
		// Depth-1 nodes fail to the root.
		for(int i = 0; i < SIGMA_SIZE; ++i)
		{
			int u = ch[0][i];
			if(u)
			{
				que[back++] = u;
				f[u] = 0;
			}
		}
		while(front < back)
		{
			int cur = que[front++];
			for(int i = 0; i < SIGMA_SIZE; ++i)
			{
				int u = ch[cur][i];
				if(u)
				{
					que[back++] = u;
					f[u] = ch[f[cur]][i];
					// A node is forbidden if its failure state is.
					val[u] |= val[f[u]];
				}
				else
				{
					// Missing edge: reuse the failure state's transition.
					ch[cur][i] = ch[f[cur]][i];
				}
			}
		}
	}
};

AC ac;

// BCD lookup: row d holds the four bits of decimal digit d, most significant
// bit first (for example, row 9 is 1 0 0 1).
int hash[10][4] = {
	{0, 0, 0, 0},   // 0
	{0, 0, 0, 1},   // 1
	{0, 0, 1, 0},   // 2
	{0, 0, 1, 1},   // 3
	{0, 1, 0, 0},   // 4
	{0, 1, 0, 1},   // 5
	{0, 1, 1, 0},   // 6
	{0, 1, 1, 1},   // 7
	{1, 0, 0, 0},   // 8
	{1, 0, 0, 1}    // 9
};


// Feeds the four BCD bits of digit `id` through the automaton, starting at u.
// u is updated in place. On success it ends on the node after the last bit;
// on failure it stays on the first node that completes a forbidden pattern.
// Returns true iff no forbidden pattern was completed.
bool move(int &u, int id)
{
	bool safe = true;
	for(int bit = 0; safe && bit < 4; ++bit)
	{
		u = ac.ch[u][hash[id][bit]];
		safe = !ac.val[u];
	}
	return safe;
}

// table[u][len]: memoized result of dfs(u, len, false, false); -1 = unknown.
int table[MAXN][210];
// digit[i]: i-th digit of the current bound counted from the least
// significant end (digit[1] is the units digit); filled in by solve().
int digit[210];

// Counts the ways to fill the remaining `len` digit positions, starting at
// trie node u, without completing a forbidden pattern (result modulo MOD).
//   bound: digits so far equal the bound's prefix, so the next digit may not
//          exceed digit[len].
//   zero:  every digit so far is a leading zero. Such zeros are skipped rather
//          than fed through the automaton, except at the last position, so the
//          number 0 itself is checked as 0000.
// Only states with neither flag set are memoized; they depend on (u, len) alone.
int dfs(int u, int len, bool bound, bool zero)
{
	if(len == 0)
		return 1;
	bool cacheable = !bound && !zero;
	if(cacheable && table[u][len] != -1)
		return table[u][len];

	int limit = bound ? digit[len] : 9;
	int total = 0;
	for(int d = 0; d <= limit; ++d)
	{
		bool nextBound = bound && d == limit;
		if(d == 0 && zero && len > 1)
		{
			// Another leading zero: stay on the same trie node.
			total = (total+dfs(u, len-1, nextBound, true))%MOD;
			continue;
		}
		int next = u;
		if(move(next, d))
			total = (total+dfs(next, len-1, nextBound, false))%MOD;
	}

	if(cacheable)
		table[u][len] = total;
	return total;
}

char num1[210], num2[210];

// Prints the number of integers in [A, B] whose BCD encoding contains no
// forbidden pattern, modulo MOD, computed as count(0..B) - count(0..A-1).
// On entry num2 holds B and num1 holds A-1 as decimal strings.
// lowerIsNegative: set when A was 0, so A-1 = -1 and nothing is subtracted.
// It defaults to false, so a plain solve() call behaves as before.
void solve(bool lowerIsNegative = false)
{
	// Only bound-free states are memoized and those depend on (u, len) alone,
	// so a single reset serves both bounds.
	memset(table, -1, sizeof(table));
	int len = strlen(num2);
	for(int i = len; i >= 1; --i)
		digit[i] = num2[len-i]-'0';
	int tans = dfs(0, len, true, true);
	if(!lowerIsNegative)
	{
		len = strlen(num1);
		for(int i = len; i >= 1; --i)
			digit[i] = num1[len-i]-'0';
		tans = (tans-dfs(0, len, true, true)+MOD)%MOD;
	}
	printf("%d\n", tans);
}

// Reads test cases: n forbidden bit patterns, then A and B; prints one answer
// per case. Stops quietly on malformed or truncated input.
int main()
{
	int TC;
	if(scanf("%d", &TC) != 1)
		return 0;
	while(TC-- > 0)
	{
		int n;
		ac.init();
		if(scanf("%d", &n) != 1)
			return 0;
		for(int i = 0; i < n; ++i)
		{
			// The width limit keeps every read inside the 210-byte buffers.
			if(scanf("%209s", num1) != 1)
				return 0;
			ac.insert(num1);
		}
		ac.construct();
		if(scanf("%209s%209s", num1, num2) != 2)
			return 0;
		// Replace A by A-1 in place. If every digit was '0' (A == 0), the borrow
		// runs off the front and A-1 would be -1, not the all-9s string left behind.
		bool borrowOut = true;
		for(int i = static_cast<int>(strlen(num1))-1; i >= 0; --i)
			if(num1[i] == '0')
				num1[i] = '9';
			else
			{
				--num1[i];
				borrowOut = false;
				break;
			}
		solve(borrowOut);
	}
	return 0;
}



//预处理,按位统计


#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <queue>
#include <algorithm>
#include <vector>
#include <cstring>
#include <stack>
#include <cctype>
#include <utility>   
#include <map>
#include <string>  
#include <climits> 
#include <set>
#include <string>    
#include <sstream>
#include <utility>   
#include <ctime>

using std::priority_queue;
using std::vector;
using std::swap;
using std::stack;
using std::sort;
using std::max;
using std::min;
using std::pair;
using std::map;
using std::string;
using std::cin;
using std::cout;
using std::set;
using std::queue;
using std::string;
using std::istringstream;
using std::make_pair;
using std::getline;
using std::greater;
using std::endl;
using std::multimap;
using std::deque;

// Shorthand type names carried over from the author's contest template.
// Apart from ULL (the type of LIM), none of them is referenced by this program.
typedef long long LL;
typedef unsigned long long ULL;
typedef std::pair<int, int> PAIR;
typedef std::multimap<int, int> MMAP;

// MAXN sizes the trie (AC::ch / f / val / que) and the rows of the memo
// table; SIGMA_SIZE is the trie alphabet size (the characters '0' and '1').
const int MAXN = 2010;
const int SIGMA_SIZE = 2;
// Template constants that this program does not use.
const int MAXM = 110;
const int MAXE = 300010;
const int MAXH = 18;
const int INFI = (INT_MAX - 1) >> 1;
// All counts are accumulated and printed modulo MOD.
const int MOD = 1000000009;
const ULL LIM = 1000000000000000ull;

// Aho-Corasick automaton over the binary alphabet {'0', '1'} (SIGMA_SIZE == 2).
// insert() builds a trie of the forbidden patterns; construct() then adds
// failure links and fills in every missing edge, so afterwards ch[u][c] is
// defined for every node u and bit c, and val[u] is true exactly when some
// inserted pattern is a suffix of the root-to-u bit string.
struct AC
{
	int ch[MAXN][SIGMA_SIZE];  // child edges; 0 doubles as "absent" before construct()
	int f[MAXN];               // failure link of each node
	bool val[MAXN];            // true if entering this node completes a forbidden pattern
	int size;                  // number of allocated nodes; node 0 is the root

	// Resets the automaton to a single empty root node.
	void init()
	{
		memset(ch[0], 0, sizeof(ch[0]));
		f[0] = 0;
		val[0] = false;
		size = 1;
	}

	// Maps a pattern character ('0' or '1') to its edge index.
	inline int idx(char temp)
	{
		return temp-'0';
	}

	// Adds the NUL-terminated pattern S (characters '0'/'1') to the trie and
	// marks its final node as forbidden. Takes const char* because the string
	// is only read; existing char* callers convert implicitly.
	// NOTE(review): there is no check against MAXN; this assumes the total
	// pattern length fits the node pool.
	void insert(const char *S)
	{
		int u = 0;
		for(; *S; ++S)
		{
			int id = idx(*S);
			if(!ch[u][id])
			{
				// Allocate a fresh node with no outgoing edges.
				memset(ch[size], 0, sizeof(ch[size]));
				val[size] = false;
				ch[u][id] = size++;
			}
			u = ch[u][id];
		}
		val[u] = true;
	}

	int que[MAXN];   // BFS queue used by construct()
	int front, back;

	// Computes failure links breadth-first and completes the transition table.
	// f[u] is strictly shallower than u, so BFS has already finalized
	// val[f[u]] and the edges of f[cur] when they are read here.
	void construct()
	{
		front = back = 0;
		// Depth-1 nodes fail to the root.
		for(int i = 0; i < SIGMA_SIZE; ++i)
		{
			int u = ch[0][i];
			if(u)
			{
				que[back++] = u;
				f[u] = 0;
			}
		}
		while(front < back)
		{
			int cur = que[front++];
			for(int i = 0; i < SIGMA_SIZE; ++i)
			{
				int u = ch[cur][i];
				if(u)
				{
					que[back++] = u;
					f[u] = ch[f[cur]][i];
					// A node is forbidden if its failure state is.
					val[u] |= val[f[u]];
				}
				else
				{
					// Missing edge: reuse the failure state's transition.
					ch[cur][i] = ch[f[cur]][i];
				}
			}
		}
	}
};

AC ac;

// table[u][k]: number of length-k digit strings (each digit 0-9 allowed and
// expanded to its 4 BCD bits) that can be read starting at trie node u
// without completing a forbidden pattern, modulo MOD. Filled by dfs()/pre().
int table[MAXN][210];
// table2[k]: leading zeros need special handling, so pre() stores here the
// number of valid integers with at most k-1 significant digits, i.e. in
// [0, 10^(k-1) - 1] (the number 0 is checked as the single digit 0).
int table2[210];
// hash[d]: the four BCD bits of decimal digit d, most significant first.
int hash[10][4] = {
	{0, 0, 0, 0},   // 0
	{0, 0, 0, 1},   // 1
	{0, 0, 1, 0},   // 2
	{0, 0, 1, 1},   // 3
	{0, 1, 0, 0},   // 4
	{0, 1, 0, 1},   // 5
	{0, 1, 1, 0},   // 6
	{0, 1, 1, 1},   // 7
	{1, 0, 0, 0},   // 8
	{1, 0, 0, 1}    // 9
};

// Digit count of the upper bound B; A-1 is zero-padded to this width in main().
int mlen;

// Advances trie node u through the four BCD bits of decimal digit `id`.
// Returns false as soon as a node that completes a forbidden pattern is
// entered (u is left on that node). Otherwise returns true, with u on the
// node reached after the last bit.
bool move(int &u, int id)
{
	const int *bits = hash[id];
	int bit = 0;
	while(bit < 4)
	{
		u = ac.ch[u][bits[bit]];
		if(ac.val[u])
			return false;
		++bit;
	}
	return true;
}

// Memoized count of length-`left` digit strings that can be read from trie
// node u without completing a forbidden pattern, modulo MOD. Every digit 0-9
// is allowed at every position, leading zeros included. table[u][left] == -1
// means "not computed yet"; pre() resets the table before use.
int dfs(int u, int left)
{
	// Only the empty string remains. Answering directly, rather than relying
	// solely on pre() having stored table[u][0] = 1, guarantees the recursion
	// never reaches table[..][-1] on a row that was not seeded.
	if(left == 0)
		return 1;
	if(table[u][left] != -1)
		return table[u][left];
	int &cur = table[u][left];
	cur = 0;
	for(int i = 0; i <= 9; ++i)
	{
		int tu = u;
		if(move(tu, i))
			cur = (cur+dfs(tu, left-1))%MOD;
	}
	return cur;
}

// Rebuilds the count tables for the current automaton and bound width mlen:
//   table[u][k]  via dfs, for every live trie node u and 1 <= k <= mlen;
//   table2[k]    number of valid integers in [0, 10^(k-1) - 1].
void pre()
{
	// Reset only the rows belonging to live trie nodes.
	memset(table, -1, sizeof(table[0])*ac.size);
	for(int node = 0; node < ac.size; ++node)
		table[node][0] = 1;
	// Increasing lengths keep each dfs call one level deep: level left-1 is done.
	for(int left = 1; left <= mlen; ++left)
		for(int node = 0; node < ac.size; ++node)
			dfs(node, left);

	// The number 0 is valid iff the single digit 0 (bits 0000) is safe.
	int state = 0;
	if(move(state, 0))
		table2[1] = 1;
	else
		table2[1] = 0;
	// Add every valid number with exactly width-1 digits: a non-zero
	// leading digit followed by width-2 unrestricted digits.
	for(int width = 2; width <= mlen; ++width)
	{
		int acc = table2[width-1];
		for(int lead = 1; lead <= 9; ++lead)
		{
			state = 0;
			if(move(state, lead))
				acc = (acc+table[state][width-2])%MOD;
		}
		table2[width] = acc;
	}
}

// Counts integers in [0, S] whose BCD encoding completes no forbidden pattern,
// modulo MOD. Leading zeros are not encoded; 0 itself is the single digit 0.
// S is a decimal string of exactly mlen characters, possibly zero-padded on
// the left. Relies on pre() having filled table and table2 for the current
// automaton and mlen.
int fun(char *S)
{
	// Skip the zero padding to reach the first significant digit.
	int pos = 0;
	while(pos < mlen && S[pos] == '0')
		++pos;
	// S is 0 itself: only the number 0 can be counted.
	if(pos == mlen)
		return table2[1];

	const int width = mlen-pos;   // significant digits in S
	// All valid numbers with fewer significant digits than S (0 included).
	int result = table2[width];

	// Same width as S but a smaller, non-zero leading digit.
	const int lead = S[pos]-'0';
	for(int d = 1; d < lead; ++d)
	{
		int head = 0;
		if(move(head, d))
			result = (result+table[head][width-1])%MOD;
	}

	// Follow S digit by digit. After each valid prefix, count the numbers that
	// share it and then place a smaller digit in the next position.
	int state = 0;
	for(int k = pos; k < mlen; ++k)
	{
		if(!move(state, S[k]-'0'))
			return result;              // this prefix of S is already invalid
		if(k == mlen-1)
			return (result+1)%MOD;      // S itself is valid
		const int next = S[k+1]-'0';
		for(int d = 0; d < next; ++d)
		{
			int probe = state;
			if(move(probe, d))
				result = (result+table[probe][mlen-k-2])%MOD;
		}
	}
	return result;
}

char num1[210], num2[210];

// Prints the number of integers in [A, B] whose BCD encoding contains no
// forbidden pattern, modulo MOD, computed as fun(B) - fun(A-1).
// num2 holds B; num1 holds A-1 zero-padded to mlen digits.
// lowerIsNegative: set when A was 0, so A-1 = -1 and nothing is subtracted.
// It defaults to false, so a plain solve() call behaves as before.
void solve(bool lowerIsNegative = false)
{
	int upper = fun(num2);
	int below = lowerIsNegative ? 0 : fun(num1);
	printf("%d\n", (upper-below+MOD)%MOD);
}

// Reads test cases: n forbidden bit patterns, then A and B; prints one answer
// per case. Stops quietly on malformed or truncated input.
int main()
{
	int TC;
	if(scanf("%d", &TC) != 1)
		return 0;
	while(TC-- > 0)
	{
		int n;
		ac.init();
		if(scanf("%d", &n) != 1)
			return 0;
		for(int i = 0; i < n; ++i)
		{
			// The width limit keeps every read inside the 210-byte buffers.
			if(scanf("%209s", num1) != 1)
				return 0;
			ac.insert(num1);
		}
		ac.construct();
		if(scanf("%209s%209s", num1, num2) != 2)
			return 0;
		mlen = strlen(num2);
		int len = strlen(num1);
		// Replace A by A-1 in place. If every digit was '0' (A == 0), the borrow
		// runs off the front and A-1 would be -1, not the all-9s string left behind.
		bool borrowOut = true;
		for(int i = len-1; i >= 0; --i)
			if(num1[i] == '0')
				num1[i] = '9';
			else
			{
				--num1[i];
				borrowOut = false;
				break;
			}
		// Left-pad A-1 with zeros to B's width so fun() sees equal-length strings.
		int temp = mlen-len;
		for(int i = len; i >= 0; --i)
			num1[i+temp] = num1[i];
		for(int i = 0; i < temp; ++i)
			num1[i] = '0';
		pre();
		solve(borrowOut);
	}
	return 0;
}






你可能感兴趣的:(BCD Code zoj3494)