(只跟别人标程拍过了,bzoj交不了)
#include
#include
#include
#include
#include
#define For(i,j,k) for(int i=(j);i<=(int)k;i++)
#define Forr(i,j,k) for(int i=(j);i>=(int)k;i--)
using namespace std;
int main(){
int n=50000,m=100000;
freopen("bzoj3575.in","w",stdout);
printf("%d %d\n",n,m);
For(i,1,n)printf("%d ",rand()%n+1);
puts("0 1");
For(i,2,n){
int u=rand()%(i-1)+1;
printf("%d %d\n",i,u);
}
For(i,1,m){
int u=rand()%n+1,v=rand()%n+1;
int a=rand()%n+1,b=rand()%n+1;
printf("%d %d %d %d \n",u,v,a,b);
}
return 0;
}
#include
#include
#include
#include
#include
#include
#define For(i,j,k) for(int i=(j);i<=(int)k;i++)
#define Forr(i,j,k) for(int i=(j);i>=(int)k;i--)
#define Set(a,b) memset(a,b,sizeof(a))
#define ll long long
#define Rep(i,u) for(int i=Begin[u],v=to[i];i;i=Next[i],v=to[i])
using namespace std;
const int N=50010,INF=0x3f3f3f3f;
inline void read(int &x){
x=0;char c=getchar();int f(0);
while(c<'0'||c>'9')f|=(c=='-'),c=getchar();
while(c>='0'&&c<='9')x=(x<<1)+(x<<3)+(c^48),c=getchar();
x=f?-x:x;
}
int Begin[N],to[N<<1],Next[N<<1],e,col[N],ans,n,m;
int dfn[N],p[N],blo,cnt,blon,rt;
inline void add(int x,int y){
to[++e]=y,Next[e]=Begin[x],Begin[x]=e;
}
struct Tree{
int fa[N][18],dep[N],st[N],top;
Tree(){top=0;}
inline int dfs(int u){
int s=0;
dfn[u]=++cnt;
Rep(i,u)
if(!dep[v]){
dep[v]=dep[u]+1,fa[v][0]=u,s+=dfs(v);
if(s>=blo){
blon++;
For(j,1,s)p[st[top--]]=blon;
s=0;
}
}
st[++top]=u;
return s+1;
}
inline void init(){
For(i,1,17)
For(j,1,n)
fa[j][i]=fa[fa[j][i-1]][i-1];
}
inline int query(int x,int y){
if(dep[x]int d=dep[x]-dep[y];
for(int i=0;(1<if(d&(1<if(x==y)return x;
Forr(i,17,0)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
}t;
struct node{
int l,id,r,v,a,b;
}a[N<<1];
inline bool cmpid(node u,node v){
return u.idinline bool cmpb(node u,node v){
return p[u.l]inline void init(){
read(n),read(m);
For(i,1,n)read(col[i]);
For(i,1,n){
int x,y;
read(x),read(y);
if(x>y)swap(x,y);
if(x==0)t.dep[rt=y]=1;
else add(x,y),add(y,x);
}
blo=(int)sqrt(n);
t.dfs(rt);
t.init();
blon++;
while(t.top)p[t.st[t.top--]]=blon;
For(i,1,m){
read(a[i].l),read(a[i].r);
if(dfn[a[i].l]>dfn[a[i].r])swap(a[i].l,a[i].r);
read(a[i].a),read(a[i].b);a[i].id=i;
}
}
int vis[N],P[N];
inline void rev(int u){
if(!vis[u]){vis[u]=1,P[col[u]]++;if(P[col[u]]==1)ans++;}
else {vis[u]=0,P[col[u]]--;if(P[col[u]]==0)ans--;}
}
inline void update(int u,int v){
while(u!=v){
if(t.dep[u]>t.dep[v])rev(u),u=t.fa[u][0];
else rev(v),v=t.fa[v][0];
}
}
inline void solve(){
sort(a+1,a+m+1,cmpb);
int lca=t.query(a[1].l,a[1].r);
update(a[1].l,a[1].r);
rev(lca);
a[1].v=ans-(a[1].a!=a[1].b&&P[a[1].a]&&P[a[1].b]);
rev(lca);
For(i,2,m){
update(a[i-1].l,a[i].l);
update(a[i-1].r,a[i].r);
lca=t.query(a[i].l,a[i].r);
rev(lca);
a[i].v=ans-(a[i].a!=a[i].b&&P[a[i].a]&&P[a[i].b]);
rev(lca);
}
sort(a+1,a+m+1,cmpid);
For(i,1,m)printf("%d\n",a[i].v);
}
int main(){
init();
solve();
return 0;
}