求长度为n,字符集大小为m的合法串个数,其中串合法当且仅当对于任意的前缀(包括空),把它接到串的末尾并删除,所得到的是一个回文串。(只要其中一个前缀合法就算合法)取模
(所有的除号都代表下取整)
记录f[i]代表长度为i,字符集大小为m,且最小周期为i的回文串个数,显然f[i]=m^((k+1)/2)-Σf[j](其中j是i的约数且不等于i)。那么答案为Σf[i]*i/2(i是n的约数且i是偶数)+Σf[i]*i(i是奇数且i是n的约数),可以这么理解:对于长度i为奇数的回文串,我们把它复制n/i次它仍然是个回文,然后把长度为i的字符循环左移任意次它依然是个合法的串;但是对于偶数的串要除以2,比如说abba,循环移位之后是abba,bbaa,baab,aabb,其中baab也是一个回文,所以要除以2
但是这题n的范围达到10^18,约数个数达到11W,本质不同的质因子达到15个。
#include
#include
#include
#include
#include
#include
using namespace std;
int a[25] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 77, 83, 89, 97};
long long n, m, MOD, num[1005], tim[1005], g[200005], leng, dp[65536], b[25], tot;
int bit[65536], LG[65536], f[200005], d[65536];
int a1[1000005], b1[1000005], c1[1000005], h[200005];
int len, i, j, k;
struct sb {
long long x;
int y;
};
inline void add(int &x, int y)
{
x += y;
if (x >= MOD) x -= MOD;
}
inline void dec(int &x, int y)
{
x -= y;
if (x < 0) x += MOD;
}
inline long long chen(long long x, long long y, long long z)
{
// cout << x << " " << y << endl;
if (!x || !y) return 0;
long long p = chen(x, y >> 1, z);
p = p + p;
if (p >= z) p -= z;
if (y & 1)
{
p += x;
if (p >= z) p -= z;
}
return p;
}
inline long long ksm(long long x, long long y, long long z)
{
long long b = 1;
while (y)
{
if (y & 1) b = chen(b, x, z);
x = chen(x, x, z);
y >>= 1;
}
return b;
}
inline int Ksm(int x, long long y, int z)
{
int b = 1;
while (y)
{
if (y & 1) b = (long long)b * x % z;
x = (long long)x * x % z;
y >>= 1;
}
return b;
}
inline bool isprim(long long n)
{
if (n < 100)
{
for(int i = 0; i < 25; i ++)
if (a[i] == n) return 1;
return 0;
}
long long nn = n - 1; int st = 0;
while (!(nn & 1)) nn >>= 1, st ++;
for(int i = 0; i < 6; i ++)
{
long long x = ksm(a[i], nn, n);
if (x == 1) continue;
bool ok = 0;
for(int j = 1; j <= st; j ++)
{
x = chen(x, x, n);
if (x == 1 || x == n - 1) ok = 1;
}
if (!ok) return 0;
if (x != 1) return 0;
}
return 1;
}
inline long long gcd(long long a, long long b)
{
long long c = a % b;
while (c)
{
a = b;
b = c;
c = a % b;
}
return b;
}
inline long long getdiv(long long n, long long c)
{
long long x = 1LL * rand() * rand() % n, y = x, d = 1;
while (d == 1)
{
x = (chen(x, x, n) + c) % n;
y = (chen(y, y, n) + c) % n;
y = (chen(y, y, n) + c) % n;
long long delta = (x > y) ? x - y : y - x;
if (delta) d = gcd(n, delta);
}
return d;
}
inline void force(long long n)
{
if (n <= 10000)
{
int N = n;
for(int i = 2; i * i <= N; i ++)
if (n % i == 0)
{
num[++len] = i;
while (n % i == 0) n /= i;
}
if (n != 1) num[++len] = n;
return;
}
if (isprim(n)) {num[++len] = n; return;}
else {
long long d = n;
while (d == n) d = getdiv(n, 1LL * rand() * rand() % (n - 1) + 1);
force(d);
force(n / d);
}
}
inline void dfs(int now, long long k)
{
if (now > len) {g[++leng] = k; return;}
long long res = 1;
for(int i = 0; i <= tim[now]; i ++)
{
dfs(now + 1, k * res);
res = res * num[now];
}
}
int main()
{
srand(19260817);
LG[0] = -1;
for(i = 1; i < 65536; i ++)
{
bit[i] = bit[i >> 1] + (i & 1);
LG[i] = LG[i >> 1] + 1;
}
int T;
cin >> T;
while (T --)
{
cin >> n >> m >> MOD;
m %= MOD;
len = 0;
force(n);
sort(num + 1, num + 1 + len);
len = unique(num + 1, num + 1 + len) - 1 - num;
for(i = 1; i <= len; i ++)
{
tim[i] = 0;
long long N = n;
while (N % num[i] == 0)
{
N /= num[i];
tim[i] ++;
}
}
leng = 0;
dfs(1, 1);
sort(g + 1, g + 1 + leng);
for(i = 1; i <= leng; i ++)
{
if (g[i] & 1) h[i] = g[i] % MOD;
else h[i] = (g[i] >> 1) % MOD;
f[i] = Ksm(m, g[i] + 1 >> 1, MOD);
}
j = leng;
int ans = 0;
memset(d, 0, sizeof(d));
for(i = 1; i <= leng; i ++)
{
while (j && n / g[i] != g[j]) j --;
if ((g[i] & 1) && !(g[j] & 1)) continue;
int opt = 0;
for(k = 1; k <= len; k ++)
if (g[j] % num[k] == 0) opt |= 1 << k - 1;
d[opt] = (d[opt] + (long long)f[i] * h[i]) % MOD;
}
ans = 0;
int all = 1 << len;
for(int s = 0; s < all; s ++)
{
if (!d[s]) continue;
int lenb = 0;
for(i = 1; i <= len; i ++)
if (s >> (i - 1) & 1) b[++lenb] = num[i] % MOD;
dp[0] = 1;
int all1 = 1 << lenb;
for(int opt = 1; opt < all1; opt ++)
{
int l = 1 << LG[opt];
dp[opt] = (long long)dp[opt ^ l] * b[LG[opt] + 1] % MOD;
}
tot = 0;
for(int opt = 0; opt < all1; opt ++)
if (bit[opt] & 1) tot -= dp[opt];
else tot += dp[opt];
tot %= MOD;
if (tot < 0) tot += MOD;
ans = (ans + (long long)tot * d[s]) % MOD;
}
cout << (ans + MOD) % MOD << endl;
}
}