题目:https://www.luogu.org/problemnew/show/P3168
萌新做的第一道主席树非模板题,emmm说实话搞得我头皮发麻,想了一个下午,最后还是去看了某神犇的题解,但是并没有看懂,似乎用了树套树(主席树套树状数组,复杂度O(nlog2n))。最后没办法,自己瞎搞居然搞出来一个O(nlog2n),瞬间感觉自己很厉害有木有,虽然我不会别人那些高端算法,但至少我的程序还比较优秀?
不多说废话,这道题给出一些区间和其对应的权值,要求某个点时包含这个点的区间中前k小的区间权值和。前k小?果断主席树。前k小的权值?简单,按照权值离散化后建树就行了。
然后,我三下五除二敲完板子,然后懵了:一般的主席树都是单点修改,这个区间修改让我很是难受。其实也不麻烦,参考树状数组的单点修改和区间修改的区别,我们差分就可以实现了:即:对于一个区间{l , r , val},我们把它拆成两个区间{l , n , val} , {r +1 , n , -val}。那么对应到权值线段树上,就是在l时将权值为val的点+1,在r+1是将权值为val的点-1。那么,这时候,我们可以发现问题只剩下怎么查询x时的线段树了。
很简单,将时间离散化,利用主席树支持访问历史操作的性质,预处理在每一个l ,r +1,树上的操作,最后查询时,只需要二分找到x在修改操作中的位置,就可以得到当时每种权值出现次数了,然后套板子。
一切都显得这样美好,但当你信心满满提交时,才发现事情没这么简单。原来,在查询时,即使你分到了最小的区间即l == r是依然可以存在区间中权值个数大于k的情况,但如果没特判直接返回,那就gg了。所以我们在最后 l == r时,将当前和值除以个数再乘k,就ok了。附代码(时间复杂度O(nlog2n)):
#include
using namespace std;
typedef long long ll;
const int maxn = 200010;
inline char get_char()
{
static char buf[100000] , *p1 = buf , *p2 = buf;
if (p1 == p2)
{
p2 = (p1 = buf) + fread(buf , 1 , 100000 , stdin);
if (p1 == p2)
{
return EOF;
}
}
return *p1++;
}
inline ll read()
{
ll res;
char ch;
while (!isdigit(ch = get_char()));
res = ch - '0';
while (isdigit(ch = get_char()))
{
res = res * 10 + ch - '0';
}
return res;
}
struct data
{
ll time , rank , val;
}T[maxn * 2];
bool cmp(data a , data b)
{
return a.time < b.time;
}
ll n , m , sz , tot , sz2;
ll l[maxn] , r[maxn] , rk[maxn] , rt[20 * maxn] , ls[maxn * 20] , rs[maxn * 20] , cnt[20 * maxn];
ll num[maxn] , tp[maxn] , sum[20 * maxn];
void build(ll& id , ll l , ll r)
{
id = ++tot;
if (l == r)
{
return;
}
ll mid = (l + r) >> 1;
build(ls[id] , l , mid);
build(rs[id] , mid + 1 , r);
}
void update(ll& id , ll l , ll r , ll last , ll pos , ll val)
{
id = ++tot;
ls[id] = ls[last];
rs[id] = rs[last];
cnt[id] = cnt[last] + val;
sum[id] = sum[last] + val * tp[pos];
if (l == r)
{
return;
}
ll mid = (l + r) >> 1;
if (pos <= mid)
{
update(ls[id] , l , mid , ls[last] , pos , val);
}
else
{
update(rs[id] , mid + 1 , r , rs[last] , pos , val);
}
}
ll query(ll s , ll l , ll r , ll k)
{
if (cnt[s] <= k)
{
return sum[s];
}
else if (l == r)
{
return sum[s] / cnt[s] * k;
}
ll mid = (l + r) >> 1;
if (cnt[ls[s]] >= k)
{
return query(ls[s] , l , mid , k);
}
else
{
return query(ls[s] , l , mid , cnt[ls[s]]) + query(rs[s] , mid + 1 , r , k - cnt[ls[s]]);
}
}
int main()
{
//freopen("data.in" , "r" , stdin);
n = read() , m = read();
ll top = 0;
for (ll i = 1; i <= n; i++)
{
l[i] = read() , r[i] = read() , num[i] = read();
tp[i] = num[i];
T[++top].time = l[i] , T[top].val = 1 , T[top].rank = i ,
T[++top].time = r[i] + 1 , T[top].val = -1 , T[top].rank = i;
}
sort(tp + 1 , tp + n + 1);
sz = unique(tp + 1 , tp + n + 1) - tp - 1;
build(rt[0] , 1 , sz);
for (ll i = 1; i <= n; i++)
{
rk[i] = lower_bound(tp + 1 , tp + sz + 1 , num[i]) - tp;
}
sort(T + 1 , T + top + 1 , cmp);
for (ll i = 1; i <= top; i++)
{
//cout << i << " " << T[i].time << " " << T[i].val << " " << rk[T[i].rank] << endl;
update(rt[i] , 1 , sz , rt[i - 1] , rk[T[i].rank] , T[i].val);
}
ll lastans = 1;
for (ll i = 1; i <= m; i++)
{
ll x = read();
ll a = read() , b = read() , c = read();
ll k = 1 + (a * lastans + b) % c;
ll le = 1 , ri = top;
ll pos = top + 1;
while (le <= ri)
{
ll mid = (le + ri) >> 1;
if (T[mid].time > x)
{
pos = mid;
ri = mid - 1;
}
else
{
le = mid + 1;
}
}
lastans = query(rt[pos - 1] , 1 , sz , k);
printf("%lld\n" , lastans);
}
return 0;
}