首先暴力模拟这样的一个插入过程,不难发现每次就是找到v∈[x,y]的出现时间的最小的,然后走过去,区间变为[x,v-1]或[v+1,y],一直到叶子节点。
先设d=gcd(b,m)
显然的结论是,2*m/d轮以后,每次插入只会使那个点的深度加一。
之所以不是m/d轮,是因为比如第x轮加了一个东西,剩下的可能加到它的子树中,第x+m/d轮时,就应是它第x轮的点的右子树的最左节点的深度+1。
如果我们能快速知道第x(x<=m/d)轮的点深度,
假设第x轮的权值是v,我们只需要讨论一下v和v+d的加入时间即可计算m/d轮以后的答案。
考虑第x(x<=m/d)轮的点深度,开头那个模拟的过程,找到[x,y]里的最小的,缩小范围,继续找,记录经过的点,实际上可以分成log段等差数列。
对于区间[x,y],先找到了v1,再找到了v2,如果v2+(v2-v1)合法,那么下次找到的一定是这个。
这个利用反证法易得(也很显然),这样就类似于gcd的过程,所以是log段的。
那么现在的目标就是找到最小的x使 l < = ( b x + a ) < = r ( m o d m ) l<=(bx+a)<=r(mod~m) l<=(bx+a)<=r(mod m)
a是常数,相当于平移,可以去掉,也就是 l < = b x < = r ( m o d m ) l<=bx<=r(mod~m) l<=bx<=r(mod m)
一个暴力的做法:
二 分 x , 相 当 于 求 有 多 少 y ∗ b m o d m < r ( 0 < = y < = x ) 二分x,相当于求有多少y*b~mod~m<r(0<=y<=x) 二分x,相当于求有多少y∗b mod m<r(0<=y<=x)
这是一个经典的问题,可以用标准类欧解决:
∑ i = 0 n [ a i m o d c < y ] \sum_{i=0}^n[ai~mod~c<y] ∑i=0n[ai mod c<y]
= ∑ i = 0 n ⌊ a i + c c ⌋ − ⌊ a i + c − y c ⌋ =\sum_{i=0}^n{\lfloor {{ai+c} \over {c}}\rfloor}-{\lfloor {{ai+c-y} \over {c}}\rfloor} =∑i=0n⌊cai+c⌋−⌊cai+c−y⌋
这样总复杂度是 O ( n l o g 3 n ) O(n~log^3n) O(n log3n)的,TLE了。
实际上找到最小的x使 l < = ( b x + a ) < = r ( m o d m ) l<=(bx+a)<=r(mod~m) l<=(bx+a)<=r(mod m)可以直接类欧实现做到一个log。
设 g ( m , d , l , r ) g(m,d,l,r) g(m,d,l,r)表示最小的 x x x使 l < = d x m o d m < = r l<=dx~mod~m<=r l<=dx mod m<=r
总复杂度 O ( n l o g 2 n ) O(n~log^2~n) O(n log2 n)
关于实现的一点小细节,其实不用判断v和v+d到底是谁早,可以求v的左偏父亲和v+d的右偏父亲的个数和即可。
Code:
#include
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i < B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;
int T;
int a, b, m; ll n;
ll gcd(ll x, ll y) { return !y ? x : gcd(y, x % y);}
int g(int m, int d, int l, int r) {
ll x = l / d;
if(l % d) x ++;
if(x * d <= r) return x;
if(d > m - d) return g(m, m - d, m - r, m - l);
int k = g(d, (d - m % d) % d, l % d, r % d);
x = (ll) k * m + l;
ll y = x / d; if(x % d) y ++;
return y;
}
int calc(int m, int d, int l, int r) {
int gd = gcd(m, d);
if(l == 0) return 0;
if((l - 1) / gd >= r / gd) return m + 1;
return g(m, d, l, r);
}
int calc2(int l, int r) {
if(l - a >= 0)
return calc(m, b, l - a, r - a);
if(r - a < 0)
return calc(m, b, l - a + m, r - a + m);
return min(calc(m, b, l - a + m, m - 1), calc(m, b, 0, r - a));
}
int calc3(int x, int l, int r, int z) {
int st = ((ll) calc2(l, r) * b + a) % m, ans = 0;
while(st != x) {
if(z) l = st + 1; else
r = st - 1;
if(l > r) return ans;
int t = calc2(l, r);
if(t > m) return ans;
int nt = ((ll) t * b + a) % m;
if(z) {
int y = (r - st) / (nt - st);
ans += y;
st += y * (nt - st);
} else {
int y = (st - l) / (st - nt);
ans += y;
st -= y * (st - nt);
}
}
return ans;
}
int gg(int x, int F) {
int ans = 0;
x = ((ll) b * x + a) % m;
int s = calc3(x, 0, x, 1);
if(F == 1) {
int d = gcd(m, b);
if(x + d < m) s += calc3(x + d, x + d, m - 1, 0) + 1;
} else {
s += calc3(x, x, m - 1, 0);
}
return s;
}
int main() {
freopen("fuwa.in", "r", stdin);
freopen("fuwa.out", "w", stdout);
scanf("%d", &T);
fo(ii, 1, T) {
scanf("%d %d %d %lld", &a, &b, &m, &n);
b %= m;
a = (a + b) % m; n --;
int d = gcd(b, m);
int m2 = m / d;
pp("%lld\n", gg(n % m2, n >= m2) + (n / m2));
}
}