[http://poj.org/problem?id=3237] (题目链接)
树链剖分模板题,然而这150+行的程序我调了一天,历经艰辛,终于ac。。
题意:给出一个n个节点的带权树,要求维护操作:1.求出树上两点之间的边权的最大值;2.更改一条边上的权值;3.将树上两点之间的所有边权取各自的相反数。
solution
神奇的树链剖分+线段树维护查询和修改操作。
树链剖分时,我们将每条边的权值转换为除树根外每个节点上的权值(也就是对于每个节点与它父亲的边的权值转换到了自己的权值)。
之后就是标准的树链剖分后跑线段树了,那个全部取相反数的操作其实是一样的,树链剖分相关知识请见 http://blog.sina.com.cn/s/blog_7a1746820100wp67.html
编程时请注意细节,邻接表写错了就悲剧了。。
为了避免各oier先WA后TLE在狂RE,附上造数据程序:
#include
#include
#include
#include
#include
using namespace std;
int main() {
int i,j,k;
freopen("aaa.in","r",stdin);freopen("aaa.in","w",stdout);
puts("1");
srand((unsigned)time(NULL));
int n=rand()%1000+5;
printf("%d\n",n);
for(i=1;iprintf("%d %d %d\n",i+1,rand()%i+1,rand()%14513546);
for(i=1;iint opt=rand()%3;
if(opt==0)
printf("CHANGE %d %d\n",rand()%(n-1)+1,rand()%42534567);
else if(opt==1) {
int a=0,b=0;
while(a==b)a=rand()%n+1,b=rand()%n+1;
printf("NEGATE %d %d\n",a,b);
}
else {
int a=0,b=0;
while(a==b)a=rand()%n+1,b=rand()%n+1;
printf("QUERY %d %d\n",a,b);
}
}
puts("DONE");
fclose(stdin);fclose(stdout);
return 0;
}
ac代码:
// poj3237
#include
#include
#include
#include
#include
#include
#define MOD 1000000007
#define inf 2147483640
#define LL long long
#define free(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout);
using namespace std;
inline int getint() {
int x=0,f=1;char ch=getchar();
while (ch>'9' || ch<'0') {if (ch=='-') f=-1;ch=getchar();}
while (ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int maxn=100010;
struct edge {int to,next,w;}e[maxn<<2];
struct tree {int l,r,tag,mn,mx;}tr[maxn<<2];
int pos[maxn],deep[maxn],head[maxn],bin[20],fa[maxn][20],size[maxn],to[maxn],bl[maxn];
int cnt,P,n;
void insert(int u,int v,int w) {
e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt;e[cnt].w=w;
e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt;e[cnt].w=w;
}
void solve(int &x,int &y) {
int t=x;x=-y;y=-t;
}
void update(int k)
{
tr[k].mn=min(tr[k<<1].mn,tr[k<<1|1].mn);
tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
}
void pushdown(int k) {
int l=tr[k].l,r=tr[k].r;
if (l==r || !tr[k].tag) return;
tr[k].tag=0;
tr[k<<1].tag^=1,tr[k<<1|1].tag^=1;
solve(tr[k<<1].mn,tr[k<<1].mx);
solve(tr[k<<1|1].mn,tr[k<<1|1].mx);
}
void build(int k,int s,int t) {
tr[k].l=s,tr[k].r=t,tr[k].tag=0,tr[k].mn=inf,tr[k].mx=-inf;
if (s==t) return;
int mid=(s+t)>>1;
build(k<<1,s,mid);
build(k<<1|1,mid+1,t);
}
void change(int k,int x,int val) {
pushdown(k);
int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1;
if (l==r) {tr[k].mn=tr[k].mx=val;return;}
if (x<=mid) change(k<<1,x,val);
else change(k<<1|1,x,val);
update(k);
}
void rever(int k,int x,int y) {
pushdown(k);
int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1;
if (l==x && r==y) {solve(tr[k].mn,tr[k].mx);tr[k].tag=1;return;}
if (y<=mid) rever(k<<1,x,y);
else if (x>mid) rever(k<<1|1,x,y);
else rever(k<<1,x,mid),rever(k<<1|1,mid+1,y);
update(k);
}
int query(int k,int x,int y) {
pushdown(k);
int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1;
if (l==x && r==y) return tr[k].mx;
if (y<=mid) return query(k<<1,x,y);
else if (x>mid) return query(k<<1|1,x,y);
else return max(query(k<<1,x,mid),query(k<<1|1,mid+1,y));
}
void dfs1(int x) {
size[x]=1;
for (int i=1;i<=13;i++) {
if (bin[i]<=deep[x]) fa[x][i]=fa[fa[x][i-1]][i-1];
else break;
}
for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0]) {
deep[e[i].to]=deep[x]+1;
fa[e[i].to][0]=x;
dfs1(e[i].to);
size[x]+=size[e[i].to];
}
}
void dfs2(int x,int chain) {
/*
if (x==22) {
++P;
--P;
}
*/
bl[x]=chain;
pos[x]=++P;
int k=0;
for (int i=head[x];i;i=e[i].next) {
if (e[i].to!=fa[x][0]) {
if (size[e[i].to]>size[k]) k=e[i].to;
}
else {
to[i>>1]=pos[x];//记录每个节点在线段树上的标号
change(1,pos[x],e[i].w);//将权值插入线段树
}
}
if (!k) return;
dfs2(k,chain);
for (int i=head[x];i;i=e[i].next)
if (e[i].to!=fa[x][0] && e[i].to!=k) dfs2(e[i].to,e[i].to);
}
int lca(int x,int y) {
if (deep[x]y]) swap(x,y);
int t=deep[x]-deep[y];
for (int i=0;i<=13;i++) if (t&bin[i]) x=fa[x][i];
for (int i=13;i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
if (x==y) return x;
return fa[x][0];
}
int solvequery(int x,int f) {
int mx=-inf;
while (bl[x]!=bl[f]) {
mx=max(mx,query(1,pos[bl[x]],pos[x]));
x=fa[bl[x]][0];
}
if (pos[f]+1<=pos[x]) mx=max(mx,query(1,pos[f]+1,pos[x]));
return mx;
}
void solverever(int x,int f) {
while (bl[x]!=bl[f]) {
rever(1,pos[bl[x]],pos[x]);
x=fa[bl[x]][0];
}
if (pos[f]+1<=pos[x]) rever(1,pos[f]+1,pos[x]);
}
int main() {
free("aaa");
int T=getint();
bin[0]=1;for (int i=1;i<15;i++) bin[i]=bin[i-1]<<1;
while (T--) {
P=0,cnt=1;//便于将边权转为点权
memset(head,0,sizeof(head));
memset(deep,0,sizeof(deep));
memset(fa,0,sizeof(fa));
n=getint();
for (int i=1;iint u=getint(),v=getint(),w=getint();
insert(u,v,w);
}
build(1,1,n);
dfs1(1);
dfs2(1,1);
char ch[10];
while (scanf("%s",ch+1)) {
if (ch[1]=='D') break;
int x=getint(),y=getint();
if (ch[1]=='Q') {
int f=lca(x,y);
printf("%d\n",max(solvequery(x,f),solvequery(y,f)));
}
if (ch[1]=='C') change(1,to[x],y);
if (ch[1]=='N') {
int f=lca(x,y);
solverever(x,f);solverever(y,f);
}
}
}
fclose(stdin);fclose(stdout);
return 0;
}