看了好久,终于把这题过了..这题的意思也很简单,就是支持3种操作,区间l,r的每个数乘上一个数,加上一个数,令为某个数.维护的信息则是某个区间这些数的1次方和,2次方和,3次方和.这题对于区间修改的线段树的理解很有帮助,以前都是套模板不知其所以然.实际上只要写好两个函数就可以了,一个pushDown,一个是pushUp, pushDown的实际就是将该点的标记往下传,传的时候要更新子结点的标记,以及更新子结点的信息,在有多个标记的时候就要考虑标记的先后顺序的影响,这个体现在标记下传的函数里.pushUp则是写好要维护的信息,这个相对就很简单.
对于本题来说,可以写3个标记或者2个标记.可以写add,set,mul3个标记,表示这个区间的相关信息,也可以只写add,mul(因为set(x)实际上是mul(0)+add(x)),不过这个没有太大影响.注意的是标记下传的时候,注意到假如我先 加x再乘y,实际上加的是xy,所以每次乘的时候要加加的标记也乘上,而加的时候对乘则没有影响.至于区间乘了x对于sum[i]实际上是乘了x^i,而加x的时候可以展开一下,如(a+x)^2=a^2+x^2+2ax,所以加一个x相当于是原来的sum[2]+2x*sum[1]+x^2*len(区间长度),这样去修改就好了.
写下自己的感受感觉会记忆深刻些.下面贴一记代码~
#include<iostream> #include<cstring> #include<string> #include<cstdio> #include<algorithm> #define maxn 100000 #define mod 10007 using namespace std; struct Node { int l,r,sum1,sum2,sum3,add,cov,mul; }N[maxn*4+200]; void build(int i,int L,int R) { N[i].l=L;N[i].r=R; N[i].cov=-1;N[i].add=0,N[i].mul=1; N[i].sum1=N[i].sum2=N[i].sum3=0; if(L==R){ return; } int M=(L+R)>>1; build(i<<1,L,M); build(i<<1|1,M+1,R); } void pushDown(int i) { if(N[i].cov!=-1){ // 默认cov为正,若不为正要修改cov的默认值 if(N[i].l!=N[i].r){ N[i<<1].add=N[i<<1|1].add=0; N[i<<1].mul=N[i<<1|1].mul=1; N[i<<1].cov=N[i<<1|1].cov=N[i].cov; N[i<<1].sum1=((N[i<<1].r-N[i<<1].l+1)%mod*N[i].cov)%mod; N[i<<1].sum2=(N[i<<1].sum1*N[i].cov)%mod; N[i<<1].sum3=(N[i<<1].sum2*N[i].cov)%mod; N[i<<1|1].sum1=(N[i<<1|1].r-N[i<<1|1].l+1)%mod*N[i].cov%mod; N[i<<1|1].sum2=N[i<<1|1].sum1%mod*N[i].cov%mod; N[i<<1|1].sum3=N[i<<1|1].sum2%mod*N[i].cov%mod; } N[i].cov=-1; } int m=N[i].mul%mod,t=N[i].add%mod; if(N[i].mul!=1||N[i].add>0){ if(N[i].l!=N[i].r){ N[i<<1].add=(N[i<<1].add*m%mod+t)%mod; N[i<<1|1].add=(N[i<<1|1].add*m%mod+t)%mod; N[i<<1].mul=N[i<<1].mul*m%mod; N[i<<1|1].mul=N[i<<1|1].mul*m%mod; int m1=m%mod,m2=m*m1%mod,m3=m*m2%mod; int s1=N[i<<1].sum1%mod,s2=N[i<<1].sum2%mod,s3=N[i<<1].sum3%mod; int t1=t%mod,t2=t*t1%mod,t3=t*t2%mod; int len=(N[i<<1].r-N[i<<1].l+1)%mod; N[i<<1].sum1=(m1*s1%mod+t1*len%mod)%mod; N[i<<1].sum2=(m2*s2%mod+2*m1%mod*s1%mod*t1%mod+t2*len%mod)%mod; N[i<<1].sum3=(m3*s3%mod+3*m2%mod*s2%mod*t%mod+3*m1%mod*s1%mod*t2%mod+t3*len%mod)%mod; s1=N[i<<1|1].sum1%mod,s2=N[i<<1|1].sum2%mod,s3=N[i<<1|1].sum3%mod; len=(N[i<<1|1].r-N[i<<1|1].l+1)%mod; N[i<<1|1].sum1=(m1*s1%mod+t1*len%mod)%mod; N[i<<1|1].sum2=(m2*s2%mod+2*m1%mod*s1%mod*t1%mod+t2*len%mod)%mod; N[i<<1|1].sum3=(m3*s3%mod+3*m2%mod*s2%mod*t1%mod+3*m1%mod*s1%mod*t2%mod+t3*len%mod)%mod; } N[i].mul=1;N[i].add=0; } } void pushUp(int i) { N[i].sum1=(N[i<<1].sum1%mod+N[i<<1|1].sum1%mod)%mod; N[i].sum2=(N[i<<1].sum2%mod+N[i<<1|1].sum2%mod)%mod; N[i].sum3=(N[i<<1].sum3%mod+N[i<<1|1].sum3%mod)%mod; } void add(int i,int L,int R,int val) { if(N[i].l==L&&N[i].r==R){ N[i].add=(N[i].add+val)%mod; int s1=N[i].sum1%mod,s2=N[i].sum2%mod,s3=N[i].sum3%mod; int t1=val%mod,t2=t1*val%mod,t3=val*t2%mod; int len=(N[i].r-N[i].l+1)%mod; N[i].sum1=(s1+t1*len)%mod; N[i].sum2=(s2+2*s1%mod*t1%mod+t2*len%mod)%mod; N[i].sum3=(s3+3*s2%mod*t1%mod+3*s1%mod*t2%mod+t3*len%mod)%mod; return; } pushDown(i); int M=(N[i].l+N[i].r)>>1; if(M>=R) add(i<<1,L,R,val); else if(M<L) add(i<<1|1,L,R,val); else{ add(i<<1,L,M,val); add(i<<1|1,M+1,R,val); } pushUp(i); } void set(int i,int L,int R,int val) { if(N[i].l==L&&N[i].r==R){ N[i].cov=val%mod;N[i].add=0;N[i].mul=1; N[i].sum1=(val%mod*((R-L+1)%mod))%mod; N[i].sum2=val%mod*N[i].sum1%mod; N[i].sum3=val%mod*N[i].sum2%mod; return; } pushDown(i); int M=(N[i].l+N[i].r)>>1; if(M>=R) set(i<<1,L,R,val); else if(M<L) set(i<<1|1,L,R,val); else{ set(i<<1,L,M,val); set(i<<1|1,M+1,R,val); } pushUp(i); } void mul(int i,int L,int R,int val) { if(N[i].l==L&&N[i].r==R){ N[i].add=N[i].add*val%mod; N[i].mul=N[i].mul*val%mod; int s1=N[i].sum1%mod,s2=N[i].sum2%mod,s3=N[i].sum3%mod; int m1=val%mod,m2=val*m1%mod,m3=val*m2%mod; N[i].sum1=(m1*s1)%mod; N[i].sum2=(m2*s2)%mod; N[i].sum3=(m3*s3)%mod; return; } pushDown(i); int M=(N[i].l+N[i].r)>>1; if(M>=R) mul(i<<1,L,R,val); else if(M<L) mul(i<<1|1,L,R,val); else{ mul(i<<1,L,M,val); mul(i<<1|1,M+1,R,val); } pushUp(i); } int query1(int i,int L,int R) { if(N[i].l==L&&N[i].r==R){ return N[i].sum1%mod; } pushDown(i); int M=(N[i].l+N[i].r)>>1; if(M>=R) return query1(i<<1,L,R)%mod; else if(M<L) return query1(i<<1|1,L,R)%mod; else{ return (query1(i<<1,L,M)+query1(i<<1|1,M+1,R))%mod; } pushUp(i); } int query2(int i,int L,int R) { if(N[i].l==L&&N[i].r==R){ return N[i].sum2%mod; } pushDown(i); int M=(N[i].l+N[i].r)>>1; if(M>=R) return query2(i<<1,L,R)%mod; else if(M<L) return query2(i<<1|1,L,R)%mod; else{ return (query2(i<<1,L,M)+query2(i<<1|1,M+1,R))%mod; } pushUp(i); } int query3(int i,int L,int R) { if(N[i].l==L&&N[i].r==R){ return N[i].sum3%mod; } pushDown(i); int M=(N[i].l+N[i].r)>>1; if(M>=R) return query3(i<<1,L,R)%mod; else if(M<L) return query3(i<<1|1,L,R)%mod; else{ return (query3(i<<1,L,M)+query3(i<<1|1,M+1,R))%mod; } pushUp(i); } int n,m; int o,l,r,p; int main() { while(cin>>n>>m&&(n||m)) { build(1,1,n); for(int i=0;i<m;i++) { scanf("%d%d%d%d",&o,&l,&r,&p); if(o==1) {add(1,l,r,p%mod);} else if(o==2) {mul(1,l,r,p%mod);} else if(o==3) {mul(1,l,r,0);add(1,l,r,p%mod);} else if(o==4){ if(p==1) printf("%d\n",query1(1,l,r)%mod); else if(p==2) printf("%d\n",query2(1,l,r)%mod); else if(p==3) printf("%d\n",query3(1,l,r)%mod); } } } return 0; }