『树上差分·线段树合并』雨天的尾巴

P r o b l e m \mathrm{Problem} Problem

N个点,形成一个树状结构。有M次发放,每次选择两个点x,y对于x到y的路径上(含x,y)每个点发一袋Z类型的物品。

完成所有发放后,每个点存放最多的是哪种物品。

S o l u t i o n \mathrm{Solution} Solution

我们可以用数组 c n t [ x ] [ v ] cnt[x][v] cnt[x][v]表示节点 x x x中权值 v v v的出现次数。

由于暴力修改的时间复杂度过高,但由于这是路径修改,我们可以考虑树上差分。

c n t [ x ] [ v ] cnt[x][v] cnt[x][v] c n t [ y ] [ v ] cnt[y][v] cnt[y][v]加上 1 1 1 c n t [ L c a ( x , y ) ] [ v ] cnt[\mathrm{Lca}(x,y)][v] cnt[Lca(x,y)][v] c n t [ f a L C A ( x , y ) ] [ v ] cnt[fa_{\mathrm{LCA(x,y)}}][v] cnt[faLCA(x,y)][v]减去 1 1 1.

但是我们再累加子树信息的时候,使用线段树合并来实现即可。

这里说一下一些小坑点:

  • 离散化以后需要保持数值的相对大小关系不变。
  • 最后当每一个位置的出现次数为 0 0 0时,需要输出 0 0 0

C o d e \mathrm{Code} Code

#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;
const int N = 1000000;
const int M = 40;

int n, m;
int tot = 0, cnt = 0;
int dep[N], size[N], f[N][M], root[N], ans[N], last[N];
struct SegmentTree { int lc, rc, cnt, ans; } tr[N*10];
struct node { int x, y, v, num; } q[N];
vector <int> a[N];
map <int,int> tag;

int read(void)
{
	int s = 0, w = 0; char c = getchar();
	while (c < '0' || c > '9') w |= c == '-', c = getchar();
	while (c >= '0' && c <= '9') s = s*10+c-48, c = getchar();
	return w ? -s : s;
}

void dfs(int x,int fa)
{
	dep[x] = dep[fa] + 1, f[x][0] = fa;
	for (int i=0;i<a[x].size();++i)
	{
		int y = a[x][i];
		if (y == fa) continue;
		dfs(y,x);
	}
	return;
}

void dp(void)
{
	for (int j=1;j<=20;++j)
	    for (int i=1;i<=n;++i)
	        f[i][j] = f[f[i][j-1]][j-1];
	return;
}

int LCA(int x,int y)
{
	if (dep[x] < dep[y]) swap(x,y);
	for (int i=0,d=dep[x]-dep[y];i<=20;++i)
	    if ((d >> i) & 1) x = f[x][i];
	if (x == y) return x;
	for (int i=20;i>=0;--i)
	    if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}

void updata(int p)
{
	if (tr[tr[p].lc].cnt >= tr[tr[p].rc].cnt) 
		tr[p].cnt = tr[tr[p].lc].cnt, tr[p].ans = tr[tr[p].lc].ans;
	else tr[p].cnt = tr[tr[p].rc].cnt, tr[p].ans = tr[tr[p].rc].ans;
}

void insert(int &p,int l,int r,int x,int v)
{
	if (p == 0) p = ++ tot;
	if (l == r) return tr[p].cnt += v, tr[p].ans = x, void();
	int mid = l + r >> 1;
	if (x <= mid) insert(tr[p].lc,l,mid,x,v);
	if (x > mid) insert(tr[p].rc,mid+1,r,x,v); 
	updata(p);
	return;
}

int merge(int p,int q,int l,int r)
{
	if (p == 0) return q;
	if (q == 0) return p;
	if (l == r) return tr[p].cnt += tr[q].cnt, p;
	int mid = l + r >> 1;
	tr[p].lc = merge(tr[p].lc,tr[q].lc,l,mid);
	tr[p].rc = merge(tr[p].rc,tr[q].rc,mid+1,r);
	updata(p);
	return p;
}

void dfs2(int x,int fa)
{
	for (int i=0;i<a[x].size();++i)
	{
		int y = a[x][i];
		if (y == fa) continue;
		dfs2(y,x);
		root[x] = merge(root[x],root[y],1,m);
	}
	if (tr[root[x]].cnt > 0)
		ans[x] = tr[root[x]].ans;
	return;
}

bool cmp(node p1,node p2) {
	return p1.v < p2.v;
}

int main(void)
{
	n = read(), m = read();
	for (int i=1,x,y;i<n;++i)
	{
		x = read(), y = read();
		a[x].push_back(y);
		a[y].push_back(x);
	}
	dfs(1,0), dp();
	for (int i=1;i<=m;++i)
		q[i] = node{read(),read(),read()};
	sort(q+1,q+m+1,cmp);
	for (int i=1;i<=m;++i) 
	{
	    if (q[i].v ^ q[i-1].v || i == 1) q[i].num = ++ cnt;
	    else q[i].num = q[i-1].num;
	    tag[q[i].num] = q[i].v;
	}
	for (int i=1,x,y,v;i<=m;++i)
	{
		x = q[i].x, y = q[i].y, v = q[i].num;
		int t = LCA(x,y);
		insert(root[x],1,m,v,1);
		insert(root[y],1,m,v,1);
		insert(root[t],1,m,v,-1);
		if (t != 1) insert(root[f[t][0]],1,m,v,-1); 
	}
	dfs2(1,0);
	for (int i=1;i<=n;++i) printf("%d\n", tag[ans[i]]);
	return 0;
}

你可能感兴趣的:(线段树,树上差分)