给你 x 0 , x 1 , a , b , x i = a x i − 1 + b x i − 2 x_0,x_1,a,b, x_i=ax_{i-1}+bx_{i-2} x0,x1,a,b,xi=axi−1+bxi−2让你求出 x n x_n xn
典型的矩阵快速幂,但是n的范围太大,所以得快速幂得用十进制快速幂
#include
using namespace std;
#define ll long long
const int maxn = 1e6 + 5;
ll mod;
struct Matrix{
ll mat[2][2];
Matrix() {memset(mat, 0, sizeof(mat));};
void init() {
mat[0][0] = mat[1][1] = 1;
}
void init(ll a, ll b) {
mat[0][0] = 0; mat[0][1] = b;
mat[1][0] = 1; mat[1][1] = a;
}
void operator = (Matrix x) {
for (int i = 0; i <= 1; i ++)
for (int j = 0; j <= 1; j ++)
mat[i][j] = x.mat[i][j];
}
};
void Print(Matrix x) {
for (int i = 0; i <= 1; i ++) {
for (int j = 0; j <= 1; j ++)
cout << x.mat[i][j] << " ";
cout << endl;
}
}
Matrix operator * (Matrix x, Matrix y) {
Matrix t;
for (int i = 0; i <= 1; i ++)
for (int j = 0; j <= 1; j ++)
for (int k = 0; k <= 1; k ++)
t.mat[i][j] = (t.mat[i][j] + x.mat[i][k] * y.mat[k][j]) % mod;
return t;
}
Matrix Ksm(Matrix x, ll b) {
//cout << b << endl;
Matrix t; t.init();
while(b) {
if(b & 1) t = t * x;
x = x * x;
b >>= 1;
}
//Print(t);
return t;
}
int main() {
ll x0, x1, a, b;
scanf("%lld %lld %lld %lld", &x0, &x1, &a, &b);
char s[maxn];
scanf("%s%lld", s, &mod);
int len = strlen(s);
reverse(s, s+len);
Matrix t, ans; t.init(a, b);
ans.mat[0][0] = x0; ans.mat[0][1] = x1;
Matrix res;
res.init();
for (int i = 0; i < len; i ++) {
res = res * Ksm(t, s[i]-'0');
t = Ksm(t, 10);
// Print(res);
// Print(t);
}
ans = ans * res;
printf("%lld\n", ans.mat[0][0]);
return 0;
}
有这么一个递推式 x i = ( a ⋅ x i − 1 + b ) m o d    p x_i=(a\cdot x_{i-1}+b)\mod p xi=(a⋅xi−1+b)modp,让你求 v v v在 [ 1 , n − 1 ] [1,n-1] [1,n−1]中第一次出现的位置
因为递推式模 p p p,所以 x x x的循环节一定小于 p p p,
而 x x x又是这种形式 x n = a ( a ( a ( a x + b ) + b ) + b ) + b x_n=a(a(a(ax+b)+b)+b)+b xn=a(a(a(ax+b)+b)+b)+b
所以我们的任务就变成 A m ≡ v m o d    p A^{m} \equiv v\mod p Am≡vmodp求最小的 m m m
A 1 = x , A 2 = ( a x + b ) , A 3 = a ( a x + b ) + b , A 4 = a ( a ( a x + b ) + b ) + b ) + b {A^1=x,A^2=(ax+b),A^3=a(ax+b)+b,A^4=a(a(ax+b)+b)+b)+b} A1=x,A2=(ax+b),A3=a(ax+b)+b,A4=a(a(ax+b)+b)+b)+b
而 A m ≡ v m o d    p A^m\equiv v\mod p Am≡vmodp明显可以用BSGS
但是BSGS的一个使用条件能不能求出 A − i ∗ S A^{-i*S} A−i∗S
但是我们怎么求出 A − i ∗ S A^{-i*S} A−i∗S呢
正常的加是乘a加b,那么除就是除a减 b a \frac{b}{a} ab
举个例子:
x 0 = x , x 1 = a x + b , x 2 = a ( a x + b ) + b , x 3 = a ( a ( a x + b ) + b ) + b x_0=x,x_1=ax+b,x_2=a(ax+b)+b,x_3=a(a(ax+b)+b)+b x0=x,x1=ax+b,x2=a(ax+b)+b,x3=a(a(ax+b)+b)+b
我们从 x 3 x_3 x3降到 x 1 x_1 x1, x 3 x_3 x3先除 a a a再减去 b a \frac{b}{a} ab变成 x 2 x_2 x2,然后再除 a a a减去 b a \frac{b}{a} ab变成 x 1 x_1 x1
那我们从 A 2 S + j A^{2S+j} A2S+j降到 A S + j A^{S+j} AS+j只需要进行 S S S次操作即可
这样我们就可以用BSGS了
跟BSGS的步骤差不多,我们可以把式子化成 A i ∗ S + j ≡ v m o d    p A^{i*S+j}\equiv v\mod p Ai∗S+j≡vmodp
我们可以预处理出来 A S A^S AS ,然后遍历找到一个 A j ≡ v ∗ A − i ∗ S m o d    p A^j\equiv v*A^{-i*S}\mod p Aj≡v∗A−i∗Smodp
也就是说在这个式子中 A − i ∗ S A^{-i*S} A−i∗S不是一个值,而是一种操作,把 v v v所代表的次数降下 S S S
x 0 = 1 , x 1 = 2 ∗ 1 + 1 , x 2 = , x 4 = 15 , x 5 = 31 x_0=1,x_1=2*1+1,x_2=,x_4=15,x_5=31 x0=1,x1=2∗1+1,x2=,x4=15,x5=31
因为我们已经预处理了一个 A S A^S AS,那么在我们遍历 i i i的过程中每次降下一个 S S S,知道找到或者找不到
用Hash存一下 A j A^j Aj
#include
using namespace std;
#define ll long long
typedef pair<int, int>pis;
const int limit = 1e6;
pis d[limit+6];
int vals[limit+6], pos[limit+6];
int Ksm(ll a, int b, int p) {
ll res = 1;
while(b) {
if(b & 1) res = res * a % p;
a = a * a % p;
b >>= 1;
}
return res;
}
int inv(int a, int p) { return Ksm(a, p-2, p); }
void solve() {
ll n, x0, a, b, p; int Q;
scanf("%lld %lld %lld %lld %lld %d", &n, &x0, &a, &b, &p, &Q);
if(!a) {
while(Q --) {
int v; scanf("%d", &v);
if(v == x0) printf("0\n");
else if(v == b) printf("1\n");
else printf("-1\n");
}
return ;
}
d[0] = {x0, 0};
for (int i = 1; i <= limit; i ++) {
int val = (a*d[i-1].first+b) % p;
d[i] = {val, i};
}
sort(d, d+limit+1);
int cnt = 0;
for (int i = 0; i <= limit; i ++) {
vals[cnt] = d[i].first; pos[cnt++] = d[i].second;
while(d[i].first == d[i+1].first && i+1 <= limit) i++;
}
int inv_a = inv(a, p);
int inv_b = (p-b) % p * inv_a % p;
ll aa = 1, bb = 0;
for (int i = 0; i <= limit; i ++) {
aa = aa * inv_a % p;
bb = (bb * inv_a + inv_b) % p;
}
while(Q --) {
int v; scanf("%d", &v);
int it = lower_bound(vals, vals+cnt, v) - vals;
if(it < cnt && vals[it] == v) {
if(pos[it] < n) printf("%d\n", pos[it]);
else printf("-1\n");
continue;
}
int m = p/(limit+1) + 3, flag = 0;
for (int i = 1; i <= m; i ++) {
v = (aa * v + bb) % p;
it = lower_bound(vals, vals+cnt, v) - vals;
if(it<cnt && vals[it] == v) {
flag = 1;
int res = i*(limit+1)+pos[it];
if(res>=n) res = -1;
printf("%d\n", res);
break;
}
}
if(!flag) printf("-1\n");
}
}
int main() {
int T;
scanf("%d", &T);
while(T --) solve();
return 0;
}