线段树模板

其实从寒假就知道线段树这个东西了,但是嫌线段树写得长,一直用树状数组。
但最近发现线段树也很不错,于是就去做洛谷的两个线段树模板题(第一个模板曾经用树状数组A过)。
又因为去机房的时间较短较分散,于是就在自习课上对着线段树的那张图imagine线段树的原理并手动coding。

对就是这张丑图
先说三个宏定义

#define ls (i << 1)
#define rs ((i << 1) | 1)
#define mid ((n[i].l + n[i].r) >> 1)

一定要加括号啊!!!
首先我们要建树
怎么看都要二分递归嘛~~~
递归到叶子(l == r)就把初始数组的相应位置的值丢到线段树里
然后往上回溯,顺便改区间值

inline void built(int i, int l, int r)//过去式防与关键字冲突
{
	n[i].l = l;
	n[i].r = r;
	if(l == r)
	{
		n[i].sm = read();
		return ;
	}
	int md = (l + r) >> 1;
	built(ls, l, md);
	built(rs, md + 1, r);
	ud(i);
	return ; 
}

如果输入顺序就是初始数组的顺序就不用存初始数组了,直接建树时读入(反正先递归左边的区间)
ud就是updata,改区间值,简单粗暴

inline void ud(int i)
{
	n[i].sm = (n[ls].sm + n[rs].sm) % P;
	return ;
}

然后就是单点修改
从根跑到叶子
跑到哪改到哪

inline void cp(int i, int k, int x)// change point
{
    if(n[i].l == k && k == n[i].r)
    {
        n[i].sm += x;
        return ;
    }
    if(k <= mid) cp(ls, k, x);
    else         cp(rs, k, x);
    return ;
}

区间查询也很无脑
如果我要查的区间完全盖住了当前区间,就返回当前的区间值。如果只有一部分重合,就摆 ~ 动 ~ (dx:摆动大法好)。和左边有重合就往左摆,和右边有重合就往右摆。注意mid是在左区间里。

inline long long gs(int i, int l, int r)
{
    if(l <= n[i].l && r >= n[i].r)
        return n[i].sm;
    long long ans = 0;
    if(l <= mid) ans += gs(ls, l, r);//区间端点完全传下去!!!
    if(r > mid)  ans += gs(rs, l, r);
    return ans;
}

区间修改就麻烦些了
如果我们仍然从根开始,跑到哪改到哪,直到每个叶子,其实复杂度比暴力还高
于是我们就想优化:Lazy!
如果当前区间完全要被修改,就只改这个区间而不去改它的儿子们(就是懒得改儿子),我们下次在见到这个区间时(不管是修改时还是查询时见到它),再去改它的F1两个儿子
也就是打Lazy标记和把Lazy标记传给儿子(Push Down)
Lazy标记的意义:当前区间已被修改而它的儿子没有改

inline void pda(int i, int ln, int rn)
{
	n[ls].lza = (n[ls].lza + n[i].lza) % P;
	n[rs].lza = (n[rs].lza + n[i].lza) % P;
	//父亲的lazy给儿子
	n[ls].sm = (n[ls].sm + n[i].lza * ln) % P;
	n[rs].sm = (n[rs].sm + n[i].lza * rn) % P;
	n[i].lza = 0;
	//它的儿子被改了,它自己就不用Lazy了
	return ;
}
inline void csa(int i, int l, int r, int x)
{
	if(n[i].l >= l && n[i].r <= r)
	{
		n[i].sm = (n[i].sm + x * (n[i].r - n[i].l + 1)) % P;
		//区间的sum都被加了x,所以要乘区间长度!!!
		n[i].lza = (n[i].lza + x) % P;
		//lazy标记就只记这个区间要+x,加多少取决于区间长,与lazy标记无关(果然有够lazy)
		return ;
	}
	if(n[i].lza) pda(i, mid - n[i].l + 1, n[i].r - mid);
	if(mid < r)  csa(rs, l, r, x);
	if(l <= mid) csa(ls, l, r, x);
	ud(i);
	return ;
}

以上内容完全可以用树状数组实现嘛QWQ
但如果加法和乘法一块改区间时树状数组就炸了
然而我线段树也炸了调了1h
我们用加法lazy和乘法lazy共同维护线段树

  • 如果我先加一个数再乘一个数,由乘法分配律知,相当于先乘一个数再加一个数
    所以传标记时先传乘再传加
  • 区间每个数都乘一个数,由乘法结合律知,相当于整个区间都乘一个数,即a1 * x + a2 * x … = x * (a1 + a2 + …),即区间乘修改时不用乘区间长度
  • lazy乘的初值赋1!!!x * 0 = 0, x * 1 = x
  • 乘法会影响加法,因为先传了乘法。所以lazy乘改了lazy加也要相应的改

其实线段树最简单的地方就是可以复制粘贴,区间乘把区间加的复制过来一改就行

inline void pdm(int i, int ln, int rn)
{
	n[ls].lzm = (n[ls].lzm * n[i].lzm) % P;
	n[rs].lzm = (n[rs].lzm * n[i].lzm) % P;
	n[ls].lza = (n[ls].lza * n[i].lzm) % P;
	n[rs].lza = (n[rs].lza * n[i].lzm) % P;
	n[ls].sm = (n[ls].sm * n[i].lzm) % P;
	n[rs].sm = (n[rs].sm * n[i].lzm) % P;
	n[i].lzm = 1;
	return ;
}
inline void csm(int i, int l, int r, int x)
{
	if(n[i].l >= l && n[i].r <= r)
	{
		n[i].sm = (n[i].sm * x) % P;
		n[i].lzm = (n[i].lzm * x) % P;
		n[i].lza = (n[i].lza * x) % P;
		return ;
	}
	if(n[i].lzm != 1) pdm(i, mid - n[i].l + 1, n[i].r - mid);
	if(n[i].lza)      pda(i, mid - n[i].l + 1, n[i].r - mid);
	if(mid < r)  csm(rs, l, r, x);
	if(l <= mid) csm(ls, l, r, x);
	ud(i);
	return ;
}

长得好像有木有っ゚Д゚)っ
最后就是完整的模板了
真的好长

/*********
push down multiply first
then push down add
*********/
#include 
using namespace std;

inline long long read()
{
    long long n = 0,k = 1;
    char ch = getchar();
    while ((ch > '9' || ch < '0') && ch != '-')  ch = getchar();
    if(ch == '-') k = -1, ch = getchar();
    while (ch <= '9' && ch >= '0')
 	{
          n = n * 10 + ch - '0';
          ch = getchar();
    }
    return n * k;
}

inline void print(long long n)
{
    if(n < 0) {putchar('-'); n = -n;}
    if(n > 9) print(n / 10);
    putchar(n % 10 + '0');
    return ;
}

struct Node
{
	int l, r;
	long long sm, lza, lzm; //lazy_multiply
	Node()
	{
		lzm = 1;
	}
}n[500420];
long long N, M, P;

#define ls (i << 1)
#define rs ((i << 1) | 1)
#define mid ((n[i].l + n[i].r) >> 1)
inline void ud(int i)
{
	n[i].sm = (n[ls].sm + n[rs].sm) % P;
	return ;
}

inline void built(int i, int l, int r)
{
	n[i].l = l;
	n[i].r = r;
	if(l == r)
	{
		n[i].sm = read();
		return ;
	}
	int md = (l + r) >> 1;
	built(ls, l, md);
	built(rs, md + 1, r);
	ud(i);
	return ; 
}

inline void pda(int i, int ln, int rn)
{
	n[ls].lza = (n[ls].lza + n[i].lza) % P;
	n[rs].lza = (n[rs].lza + n[i].lza) % P;
	n[ls].sm = (n[ls].sm + n[i].lza * ln) % P;
	n[rs].sm = (n[rs].sm + n[i].lza * rn) % P;
	n[i].lza = 0;
	return ;
}

inline void pdm(int i, int ln, int rn)
{
	n[ls].lzm = (n[ls].lzm * n[i].lzm) % P;
	n[rs].lzm = (n[rs].lzm * n[i].lzm) % P;
	n[ls].lza = (n[ls].lza * n[i].lzm) % P;
	n[rs].lza = (n[rs].lza * n[i].lzm) % P;
	n[ls].sm = (n[ls].sm * n[i].lzm) % P;
	n[rs].sm = (n[rs].sm * n[i].lzm) % P;
	n[i].lzm = 1;
	return ;
}

inline long long as(int i, int l, int r)  // answer section
{
	if(n[i].l >= l && n[i].r <= r)
	    return n[i].sm;
	if(n[i].lzm != 1) pdm(i, mid - n[i].l + 1, n[i].r - mid);
	if(n[i].lza)      pda(i, mid - n[i].l + 1, n[i].r - mid);
	long long ans = 0;
	if(l <= mid) ans = (ans + as(ls, l, r)) % P;
	if(r > mid)  ans = (ans + as(rs, l, r)) % P;
	return ans;
}

inline void csa(int i, int l, int r, int x)
{
	if(n[i].l >= l && n[i].r <= r)
	{
		n[i].sm = (n[i].sm + x * (n[i].r - n[i].l + 1)) % P;
		n[i].lza = (n[i].lza + x) % P;
		return ;
	}
	if(n[i].lzm != 1) pdm(i, mid - n[i].l + 1, n[i].r - mid);
	if(n[i].lza)      pda(i, mid - n[i].l + 1, n[i].r - mid);
	if(mid < r)  csa(rs, l, r, x);
	if(l <= mid) csa(ls, l, r, x);
	ud(i);
	return ;
}

inline void csm(int i, int l, int r, int x)
{
	if(n[i].l >= l && n[i].r <= r)
	{
		n[i].sm = (n[i].sm * x) % P;
		n[i].lzm = (n[i].lzm * x) % P;
		n[i].lza = (n[i].lza * x) % P;
		return ;
	}
	if(n[i].lzm != 1) pdm(i, mid - n[i].l + 1, n[i].r - mid);
	if(n[i].lza)      pda(i, mid - n[i].l + 1, n[i].r - mid);
	if(mid < r)  csm(rs, l, r, x);
	if(l <= mid) csm(ls, l, r, x);
	ud(i);
	return ;
}

inline void prt()
{
	putchar('#');
	for(register int i = 1; i <= N; i++)
	    printf("%lld ", as(1, i, i));
	putchar(10);
	return ;
}

int main()
{
	N = read();
	M = read();
	P = read();
	built(1, 1, N);
	//prt();
	register int f, x, y, z;
	for(register int i = 1; i <= M; i++)
	{
		f = read();
		if(f == 1)
		{
			x = read();
			y = read();
			z = read();
			csm(1, x, y, z);
			//prt();
		}
		else if(f == 2)
		{
			x = read();
			y = read();
			z = read();
			csa(1, x, y, z);
			//prt();
		}
		else
		{
			x = read();
			y = read();
			print(as(1, x, y));
			putchar(10);
			//prt();
		}
	}
	return 0;
}

结合树剖食用更佳(~ ̄▽ ̄)~

你可能感兴趣的:(板子)