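// 树链剖分-重链剖分 (heavy-light decomposition, heavy-path variant)

// Luogu P3384 【模板】重链剖分/树链剖分 — template problem for heavy-light decomposition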

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;


/// Reads one (optionally negative) decimal integer from stdin via getchar().
///
/// Non-digit characters before the number are skipped; a '-' seen while
/// skipping makes the result negative. The first non-digit after the number
/// is consumed.
///
/// @param x  receives the parsed value, including its sign.
/// @return   the parsed value (equal to x); 0 if EOF is hit before any digit.
///
/// NOTE(review): this uses C stdio, while `speedup` disables
/// sync_with_stdio — do not interleave it with std::cin on the same input.
inline int read(int& x) {
	int ch = getchar();  // int, not char, so EOF can be told apart from data
	int f = 1;
	x = 0;
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	x *= f;  // store the sign in the out-parameter too, not only in the return value
	return x;
}
//void ReadFile() {
//	FILE* stream1;
//	freopen_s(&stream1,"in.txt", "r", stdin);
//	freopen_s(&stream1,"out.txt", "w", stdout);
//}

// Executed once during static initialization: decouples C++ streams from C
// stdio and unties cin from cout, so reading input does not flush output.
// The stored value itself is never used; it only forces the lambda to run.
static auto speedup = [] {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);
	return nullptr;
}();

// 64-bit type used for segment-tree sums and lazy tags.
using ll = long long int;
// Capacity of the per-vertex and per-edge arrays (2 * 10^5 + 7).
constexpr int maxn = 200000 + 7;
// Modulus applied to all arithmetic; read from input in main().
int mod;

class segment {
public:
	segment() {
		memset(m_data, 0, sizeof(m_data));
		memset(m_lazy, 0, sizeof(m_lazy));
	}
	void build(int l,int r,int *arr,int idx) {
		if (l == r) {
			m_data[idx] = arr[l] % mod;
			return;
		}
		int mid = (l + r) >> 1;
		int lson = idx * 2,rson = lson + 1;
		build(l,mid,arr,lson);
		build(mid + 1, r, arr, rson);
		push_up(idx);
	}
	void push_up(int idx) {
		m_data[idx] = m_data[idx << 1] + m_data[idx << 1 | 1];
	}
	void push_down(int l,int r,int idx) {
		if (!m_lazy[idx])return;
		ll lson = idx * 2, rson = lson + 1,val = m_lazy[idx];
		int mid = (l + r) >> 1,lcnt = mid - l + 1,rcnt = r - mid;
		m_lazy[lson] = (m_lazy[lson] + val) % mod;
		m_lazy[rson] = (m_lazy[rson] + val) % mod;
		m_data[lson] = (m_data[lson] + val * lcnt) % mod;
		m_data[rson] = (m_data[rson] + val * rcnt) % mod;
		m_lazy[idx] = 0;
	}

	void _modify(int l,int r,int idx,int cl, int cr, int val) {
		if (l > cr || r < cl) return;
		if (l >= cl && r <= cr) {
			m_data[idx] = (m_data[idx] + val * (r - l + 1)) % mod;
			m_lazy[idx] = (m_lazy[idx] + val) % mod;
			return;
		}
		push_down(l, r, idx);
		int mid = (l + r) >> 1,lson = idx * 2,rson = lson + 1;
		
		if(mid >= cl) _modify(l,mid,lson,cl,cr,val);
		if(mid < cr) _modify(mid + 1,r, rson, cl, cr, val);
		push_up(idx);
	}
	void modify(int l, int r, int val) {
		_modify(1, cnt,1, l, r, val);
	}

	ll _query(int l, int r, int idx, int ql, int qr) {
		if (l > qr || r < ql) return 0;
		if (l >= ql && r <= qr) {
			return m_data[idx];
		}
		push_down(l, r, idx);
		int mid = (l + r) >> 1, lson = idx * 2, rson = lson + 1;
		ll ans = 0;

		if (mid >= ql) ans = (ans + _query(l, mid, lson, ql, qr)) % mod;
		if (mid < qr) ans = (ans + _query(mid + 1, r, rson, ql, qr)) % mod;
		return ans;
	}
	ll query(int l, int r) {
		return _query(1, cnt, 1, l, r);
	}
public:
	ll m_data[maxn << 2];
	ll m_lazy[maxn << 2];
	int cnt;
};

segment tree;

/// One directed edge in the forward-star adjacency storage.
struct Tail {
	int to;    // vertex this edge points to
	int next;  // index of the next edge leaving the same vertex (traversal stops at 0)
	Tail() : to{0}, next{0} {}
};
Tail tail[maxn];

// head[v]: index into `tail` of the most recently added edge leaving v.
//   The traversal loops treat index 0 as "end of list / no edge".
// tot: index of the last edge created. It starts at 0 so the first edge gets
//   index 1. Starting at -1 gave the first edge index 0, which every
//   `cur != 0` loop mistakes for end-of-list, silently dropping that edge.
//   NOTE(review): assumes the number of directed edges stays below maxn.
int head[maxn] = { 0 }, tot = 0;
// visit: not referenced anywhere in the visible code.
bool visit[maxn] = { 0 };

// Forward-star (array-linked adjacency list) storage:
//   head[v]       index of the first edge in v's list (the head pointer),
//   tail[e].to    vertex that edge e points to,
//   tail[e].next  index of the next edge in the same list.
/// Adds the directed edge _head -> _tail by prepending it to _head's list.
void AddEdge(int _head, int _tail) {
	const int e = ++tot;         // allocate a fresh edge slot
	tail[e].to = _tail;          // record where the edge leads
	tail[e].next = head[_head];  // link it in front of the current first edge
	head[_head] = e;             // the new edge becomes the list head
}

// Filled by the first DFS:
int fa[maxn];   // parent (main calls dfs(r, r), so the root is its own parent)
int dep[maxn];  // depth; the root gets dep[root] + 1 from its zero-initialized entry, i.e. 1
int siz[maxn];  // subtree size
int son[maxn];  // heavy child; stays 0 for a vertex with no children

/// First pass of heavy-light decomposition: records parent, depth and subtree
/// size for every vertex reached from u, and picks as heavy child the first
/// child with the strictly largest subtree.
/// @param u  current vertex
/// @param f  parent of u (main passes u itself for the root)
/// NOTE(review): recursion depth equals tree height; very deep trees may exhaust the stack.
void dfs(int u, int f) {
	fa[u] = f;
	siz[u] = 1;
	dep[u] = 1 + dep[f];
	int best = 0;  // size of the heaviest child subtree found so far
	for (int e = head[u]; e != 0; e = tail[e].next) {
		const int child = tail[e].to;
		if (child != f) {
			dfs(child, u);
			siz[u] += siz[child];
			if (best < siz[child]) {
				best = siz[child];
				son[u] = child;
			}
		}
	}
}

// Filled by the second DFS (w is read in main):
int tim = 0;     // running DFS timestamp
int dfn[maxn];   // position of each vertex in the heavy-child-first DFS order
int top[maxn];   // topmost vertex of the heavy chain containing each vertex
int wt[maxn];    // weights re-indexed by dfn position
int w[maxn];     // original vertex weights

/// Second pass: assigns DFS order visiting the heavy child first, so every
/// heavy chain and every subtree occupies a contiguous range of dfn positions.
/// @param u     current vertex
/// @param topf  head of the heavy chain u belongs to
void dfs2(int u, int topf) {
	top[u] = topf;
	dfn[u] = ++tim;
	wt[tim] = w[u];
	const int heavy = son[u];
	if (heavy == 0) return;  // no heavy child: nothing below u
	dfs2(heavy, topf);       // the heavy child continues u's chain
	for (int e = head[u]; e; e = tail[e].next) {
		const int v = tail[e].to;
		if (v != fa[u] && v != heavy) dfs2(v, v);  // each light child starts a new chain
	}
}

// Vertex count, operation count and root vertex; all read in main().
int n{}, m{}, r{};

/// Adds val to the weight of every vertex in the subtree rooted at idx.
/// In dfn order that subtree is the contiguous range [dfn[idx], dfn[idx] + siz[idx] - 1].
void subTreeAdd(int idx, int val) {
	const int first = dfn[idx];
	const int last = first + siz[idx] - 1;
	tree.modify(first, last, val);
}

/// Returns the sum of weights in the subtree rooted at idx, reduced modulo mod.
/// In dfn order that subtree is the contiguous range [dfn[idx], dfn[idx] + siz[idx] - 1].
ll subTreeSum(int idx) {
	const int first = dfn[idx];
	const int last = first + siz[idx] - 1;
	// Reduce explicitly, as queryRoute does, so the printed answer is always in
	// [0, mod) whether or not every stored segment-tree node is already reduced.
	return tree.query(first, last) % mod;
}

/// Adds val to every vertex on the tree path between l and r (inclusive).
/// While the endpoints lie on different heavy chains, the endpoint whose chain
/// head is deeper has its chain segment updated and jumps above that head.
/// Once both share a chain, the remaining contiguous dfn range is updated.
void routeAdd(int l, int r, int val) {
	for (; top[l] != top[r]; l = fa[top[l]]) {
		if (dep[top[l]] < dep[top[r]]) swap(l, r);
		tree.modify(dfn[top[l]], dfn[l], val);
	}
	const bool lIsShallower = dep[l] <= dep[r];
	const int upper = lIsShallower ? l : r;
	const int lower = lIsShallower ? r : l;
	tree.modify(dfn[upper], dfn[lower], val);
}

int queryRoute(int l, int r) {
	ll ans = 0;
	while (top[l] != top[r]) {
		if (dep[top[l]] < dep[top[r]]) swap(l, r);
		ans = (ans + tree.query(dfn[top[l]], dfn[l])) % mod;
		l = fa[top[l]];
	}
	if (dep[l] > dep[r])swap(l, r);
	ans = (ans + tree.query(dfn[l], dfn[r])) % mod;
	return ans;
}
/// Reads a P3384-style instance and answers its operations.
///
/// Input: n m r mod, then n vertex weights, then n-1 undirected edges, then
/// m operations:
///   1 x y z : add z to every vertex on the path x..y
///   2 x y   : print the path sum x..y (mod `mod`)
///   3 x z   : add z to every vertex in x's subtree
///   any other code, followed by x : print the subtree sum of x
int main()
{
	//ReadFile();
	cin >> n >> m >> r >> mod;
	for (int i = 1; i <= n; i++) {
		cin >> w[i];
	}
	int a, b, c, d;
	for (int i = 1; i < n; i++) {
		cin >> a >> b;
		AddEdge(a, b);
		AddEdge(b, a);
	}
	// Decompose the tree rooted at r, then build the segment tree over the
	// weights laid out in dfn order.
	dfs(r, r);
	dfs2(r, r);
	tree.cnt = n;
	tree.build(1, n, wt, 1);

	for (int i = 1; i <= m; i++) {
		// Stop on truncated input instead of executing operations on failed reads.
		if (!(cin >> a)) break;
		if (a == 1) {
			cin >> b >> c >> d;
			routeAdd(b, c, d);
		}
		else if (a == 2) {
			cin >> b >> c;
			// '\n' instead of endl: endl would flush the stream for every answer.
			cout << queryRoute(b, c) << '\n';
		}
		else if (a == 3) {
			cin >> b >> c;
			subTreeAdd(b, c);
		}
		else {
			cin >> b;
			cout << subTreeSum(b) << '\n';
		}
	}
	return 0;
}

// 你可能感兴趣的:(树论,c++) — blog footer: "You may also be interested in: (tree theory, c++)"