bzoj 5330反回文串

求长度为n,字符集大小为m的合法串个数,其中串合法当且仅当对于任意的前缀(包括空),把它接到串的末尾并删除,所得到的是一个回文串。(只要其中一个前缀合法就算合法)取模

(所有的除号都代表下取整)

记录f[i]代表长度为i,字符集大小为m,且最小周期为i的回文串个数,显然f[i]=m^((k+1)/2)-Σf[j](其中j是i的约数且不等于i)。那么答案为Σf[i]*i/2(i是n的约数且i是偶数)+Σf[i]*i(i是奇数且i是n的约数),可以这么理解:对于长度i为奇数的回文串,我们把它复制n/i次它仍然是个回文,然后把长度为i的字符循环左移任意次它依然是个合法的串;但是对于偶数的串要除以2,比如说abba,循环移位之后是abba,bbaa,baab,aabb,其中baab也是一个回文,所以要除以2

但是这题n的范围达到10^18,约数个数达到11W,本质不同的质因子达到15个。

若i是奇数h[i]=i,否则h[i]=i/2bzoj 5330反回文串_第1张图片

#include
#include
#include
#include
#include
#include
using namespace std;
int a[25] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 77, 83, 89, 97};
long long n, m, MOD, num[1005], tim[1005], g[200005], leng, dp[65536], b[25], tot;
int bit[65536], LG[65536], f[200005], d[65536];
int a1[1000005], b1[1000005], c1[1000005], h[200005];
int len, i, j, k;
struct sb {
	long long x;
	int y;
};
inline void add(int &x, int y)
{
	x += y;
	if (x >= MOD) x -= MOD;
}
inline void dec(int &x, int y)
{
	x -= y;
	if (x < 0) x += MOD;
}
inline long long chen(long long x, long long y, long long z)
{
//	cout << x << " " << y << endl;
	if (!x || !y) return 0;
	long long p = chen(x, y >> 1, z);
	p = p + p;
	if (p >= z) p -= z;
	if (y & 1)
	{
		p += x;
		if (p >= z) p -= z;
	}
	return p;
}
inline long long ksm(long long x, long long y, long long z)
{
	long long b = 1;
	while (y)
	{
		if (y & 1) b = chen(b, x, z);
		x = chen(x, x, z);
		y >>= 1;
	}
	return b;
}
inline int Ksm(int x, long long y, int z)
{
	int b = 1;
	while (y)
	{
		if (y & 1) b = (long long)b * x % z;
		x = (long long)x * x % z;
		y >>= 1;
	}
	return b;
}
inline bool isprim(long long n)
{
	if (n < 100)
	{
		for(int i = 0; i < 25; i ++)
			if (a[i] == n) return 1;
		return 0;
	}
	long long nn = n - 1; int  st = 0;
	while (!(nn & 1)) nn >>= 1, st ++;
	for(int i = 0; i < 6; i ++)
	{
		long long x = ksm(a[i], nn, n);
		if (x == 1) continue;
		bool ok = 0;
		for(int j = 1; j <= st; j ++)
		{
			x = chen(x, x, n);
			if (x == 1 || x == n - 1) ok = 1;
		}
		if (!ok) return 0;
		if (x != 1) return 0;
	}
	return 1;
}
inline long long gcd(long long a, long long b)
{
	long long c = a % b;
	while (c)
	{
		a = b;
		b = c;
		c = a % b;
	}
	return b;
}
inline long long getdiv(long long n, long long c)
{
	long long x = 1LL * rand() * rand() % n, y = x, d = 1;
	while (d == 1)
	{
		x = (chen(x, x, n) + c) % n;
		y = (chen(y, y, n) + c) % n;
		y = (chen(y, y, n) + c) % n;
		long long delta = (x > y) ? x - y : y - x;
		if (delta) d = gcd(n, delta);
	}
	return d;
}
inline void force(long long n)
{
	if (n <= 10000)
	{
		int N = n;
		for(int i = 2; i * i <= N; i ++)
			if (n % i == 0)
			{
				num[++len] = i;
				while (n % i == 0) n /= i;
			}
		if (n != 1) num[++len] = n;
		return;
	}
	if (isprim(n)) {num[++len] = n; return;}
	else {
		long long d = n;
		while (d == n) d = getdiv(n, 1LL * rand() * rand() % (n - 1) + 1);
		force(d);
		force(n / d);
	}
}
inline void dfs(int now, long long k)
{
	if (now > len) {g[++leng] = k; return;}
	long long res = 1;
	for(int i = 0; i <= tim[now]; i ++)
	{
		dfs(now + 1, k * res);
		res = res * num[now];
	}
}
int main()
{
	srand(19260817);
	LG[0] = -1;
	for(i = 1; i < 65536; i ++)
	{
		bit[i] = bit[i >> 1] + (i & 1);
		LG[i] = LG[i >> 1] + 1;
	}
	int T;
	cin >> T;
	while (T --)
	{
		cin >> n >> m >> MOD;
		m %= MOD;
		len = 0;
		force(n);
		sort(num + 1, num + 1 + len);
		len = unique(num + 1, num + 1 + len) - 1 - num;
		for(i = 1; i <= len; i ++)
		{
			tim[i] = 0;
			long long N = n;
			while (N % num[i] == 0)
			{
				N /= num[i];
				tim[i] ++;
			}
		}
		leng = 0;
		dfs(1, 1);
		sort(g + 1, g + 1 + leng);
		for(i = 1; i <= leng; i ++)
		{
			if (g[i] & 1) h[i] = g[i] % MOD;
			else h[i] = (g[i] >> 1) % MOD;
			f[i] = Ksm(m, g[i] + 1 >> 1, MOD);
		}
		j = leng;
		int ans = 0;
		memset(d, 0, sizeof(d));
		for(i = 1; i <= leng; i ++)
		{
			while (j && n / g[i] != g[j]) j --;
			if ((g[i] & 1) && !(g[j] & 1)) continue;
			int opt = 0;
			for(k = 1; k <= len; k ++)
				if (g[j] % num[k] == 0) opt |= 1 << k - 1;
			d[opt] = (d[opt] + (long long)f[i] * h[i]) % MOD;
		}
		ans = 0;
		int all = 1 << len;
		for(int s = 0; s < all; s ++)
		{
			if (!d[s]) continue;
			int lenb = 0;
			for(i = 1; i <= len; i ++)
				if (s >> (i - 1) & 1) b[++lenb] = num[i] % MOD;
			dp[0] = 1;
			int all1 = 1 << lenb;
			for(int opt = 1; opt < all1; opt ++)
			{
				int l = 1 << LG[opt];
				dp[opt] = (long long)dp[opt ^ l] * b[LG[opt] + 1] % MOD;
			}
			tot = 0;
			for(int opt = 0; opt < all1; opt ++)
				if (bit[opt] & 1) tot -= dp[opt];
				else tot += dp[opt];
			tot %= MOD;
			if (tot < 0) tot += MOD;
			ans = (ans + (long long)tot * d[s]) % MOD;
		}
		cout << (ans + MOD) % MOD << endl;
	}
}

你可能感兴趣的:(简单计数)