Comet OJ - Contest #14(set区间操作 + 树状数组)

Comet OJ - Contest #14

Comet OJ - Contest #14(set区间操作 + 树状数组)_第1张图片

做法

这题是一个很骚的做法。

因为每次是把整个区间覆盖为某个数字,所以可以看作是把一段区间内的很多段数字合并成一个的过程。

我们考虑用 s e t set set去维护这个过程, s e t set set里面保存四元组 ( l , r , x , i d ) (l,r,x,id) (l,r,x,id),表示区间 [ l , r ] [l,r] [l,r]都是 x x x且是在第 i d id id个操作之后改变的。 s e t set set r r r为关键字排序。同时,用一个树状数组表示执行了前 i i i个操作之后的所有数字总和,每次维护每个操作之后总和的变化。

显然,在一开始的时候,整个 s e t set set里面只有 ( 1 , m , 0 , 0 ) (1,m,0,0) (1,m,0,0),即全部是0。对于一个操作 ( l , r , x ) (l,r,x) (l,r,x),首先找到包含 l l l的那个区间,假设为 [ L , R ] [L,R] [L,R],把它分为 [ L , l − 1 ] [L,l-1] [L,l1] [ l , R ] [l,R] [l,R]两个区间,对 r r r也做相同的操作。这样做了之后,就会出现以 l l l为左端点的区间和以 r r r为右端点的区间,这样接下来我们就可以确定 s e t set set内这两个区间之间的所有区间。而这些区间就是我们需要去掉的,直接遍历一遍,在树状数组里面把每个区间对应的 i d id id及其之后的所有总和减去 ( r − l + 1 ) ∗ x (r-l+1)*x (rl+1)x。最后,再把当前区间对应的四元组 ( l , r , x , i d ) (l,r,x,id) (l,r,x,id)加入,对应的在树状数组的 i d id id及其之后都加上相应的和即可。

接下来我们来考虑一下复杂度。对于每个操作,我们最多会拆分左右两个区间,对应增加两个四元组,总共就是 2 n 2n 2n个四元组。而对于每个四元组,我们最多会在 s e t set set中插入和删除一次,每次插入和删除的复杂度是 O ( l o g N ) O(logN) O(logN),这里的复杂度就是 O ( N l o g N ) O(NlogN) O(NlogN)。然后,每个操作,我们会在set里面找其左右端点对应的位置,这里每次也是 O ( l o g N ) O(logN) O(logN)的,总时间复杂度不会边。最后就是树状数组的修改,对应每个四元组最多两次修改,因此整个过程下来复杂度还是 O ( N l o g N ) O(NlogN) O(NlogN)

代码

#include
#define fi first
#define se second
#define LL long long
#define pb push_back
#define lb lower_bound
#define INF 0x3f3f3f3f
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)

using namespace std;

const int N=1000010;

struct Operation
{
	int l,r,x,id;

	bool operator < (const Operation &a) const
	{
		return a.r>r;
	}

} op[N];

struct query{int l,r,id;} q[N];

set<Operation> st;
LL c[N],ans[N];
int n,m,Q;

inline void update(int x,LL y)
{
    if (x==0) return;
	for(int i=x;i<N;i+=i&-i)
		c[i]+=y;
}

inline LL getsum(int x)
{
	LL res=0;
	for(int i=x;i;i-=i&-i)
		res+=c[i];
	return res;
}

inline bool cmp(query a,query b)
{
	return a.r<b.r;
}

inline void split(set<Operation>::iterator it,int x)
{
    if (x<it->l||x>it->r) return;
	int l=it->l,r=it->r;
	int v=it->x,id=it->id;
	st.erase(it);
	st.insert({l,x,v,id});
	st.insert({x+1,r,v,id});
}

int main()
{
	sccc(n,m,Q);
	for(int i=1;i<=n;i++)
		sccc(op[i].l,op[i].r,op[i].x);
	for(int i=1;i<=Q;i++)
	{
		scc(q[i].l,q[i].r);
		q[i].id=i;
	}
	st.insert({1,m,0,0});
	sort(q+1,q+1+Q,cmp);
	for(int i=1;i<=Q;i++)
	{
		for(int j=q[i-1].r+1;j<=q[i].r;j++)
		{
			int l=op[j].l,r=op[j].r,x=op[j].x;
			auto L=st.lb({0,l-1,0,0});
			split(L,l-1);
			auto R=st.lb({0,r,0,0});
			split(R,r);
			L=st.lb({0,l,0,0});
			R=st.lb({0,r+1,0,0});
			while(L!=R)
			{
				auto cur=L; L++;
				update(cur->id,(LL)-(cur->r-cur->l+1)*cur->x);
				st.erase(cur);
			}
			update(j,(LL)(r-l+1)*x);
			st.insert({l,r,x,j});
		}
		ans[q[i].id]=getsum(q[i].r)-getsum(q[i].l-1);
	}
	for(int i=1;i<=Q;i++)
		printf("%lld\n",ans[i]);
	return 0;
}

你可能感兴趣的:(树状数组)