有多组样例,每组样例第一行输入两个正整数n,m(2 <= n<=50000,1<=m <= 50000),接下来n-1行,每行3个正整数a b c,(1 <= a,b <= n , a != b , 1 <= c <= 1000000000).数据保证给的路使得任意两座城市互相可达。接下来输入m行,表示m个操作,操作有两种:一. 0 a b,表示更新第a条路的过路费为b,1 <= a <= n-1 ; 二. 1 a b , 表示询问a到b最少要花多少过路费。
树链剖分..
==debug了好久。。就是自己逗。。 QUERY函数写错了。。
#include
#include
#include
#include
#include
using namespace std;
const int N = 155555;
const int M = 1111111;
struct Edge{
int u,v,w;
Edge(){}
Edge(int _l,int _r,int _w){
u=_l, v=_r,w=_w;
}
}edge[M];
struct node{
int to,w;
node(){}
node(int _to,int _w){to=_to,w=_w;}
};
vector V[N];
int n,m;
int dep[N],fa[N],sz[N],son[N];
int tim[N],ra[N],top[N];
int tx;
int val[N];
void dfs1(int x,int Fa,int Deep){
fa[x] = Fa , sz[x] = 1 , dep[x] = Deep;
for(int i = 0;i < V[x].size();i++){
int v = V[x][i].to;
if( v!=Fa){
val[v] = V[x][i].w;
dfs1(v,x,Deep+1);
sz[x] += sz[v];
if(son[x]==-1 || sz[v] > sz[son[x]]){
son[x] = v;
}
}
}
}
void dfs2(int u,int Top){
tim[u] = ++tx; ra[tx] = u;
top[u] = Top;
if(son[u]!=-1)
dfs2(son[u],Top);
for(int i = 0;i < V[u].size();i++){
int v = V[u][i].to;
if(v!=fa[u] && v!=son[u]){
dfs2(v,v);
}
}
}
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define ll long long
ll sum[N<<2];
void Push_up(int rt){
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void build(int l,int r,int rt){
if(l==r){
sum[rt] = val[ra[l]];
return ;
}
int m=(l+r)>>1;
build(lson);
build(rson);
Push_up(rt);
}
int ql,qr;
int qv;
void update(int l,int r,int rt){
if(l==r){
sum[rt] = qv;
return;
}
int m = (l+r)>>1;
if(ql<=m)update(lson);
if(ql>m)update(rson);
Push_up(rt);
}
ll query(int l,int r,int rt){
if(ql<=l && r <= qr){
return sum[rt];
}
int m=(l+r)>>1;
ll ans = 0;
if(ql<=m) ans+=query(lson);
if(qr>m) ans+=query(rson);
return ans;
}
void UPDATE(int idx,int w){
int u = edge[idx].u, v = edge[idx].v;
if(dep[u] < dep[v]) swap(u,v);
val[u] = w;
qv = w;
ql=tim[u];
update(1,n,1);
}
ll QUERY(int x,int y){
int f1 = top[x] , f2 = top[y];
long long ans = 0;
while(f1!=f2){
if(dep[f1] dep[y]) swap(x,y);
ql = tim[x], qr = tim[y] , ans += query(1,n,1);
ans -= val[x];
return ans;
}
int main(){
while(scanf("%d %d",&n,&m)!=EOF){
for(int i=0;i<=n;i++) V[i].clear();
for(int i=1,u,v,w;i