wa了12发的点分治终于过了,就是xjb乱搞题……
维护了六个量,(u到v的链上出现的第一种颜色col,col的次数cnt,最后一种链的颜色las,las的次数num,链的长度len,链的权值w)
统计必过u的答案的时候,
用 任意两个合法的 减去 任意两个col相同的合法的,作异色答案
加上col相同且长度均为1的(不抵消的)合法的,作同色不抵消,
加上col相同且长度之和大于等于3的(抵消的)合法的,作同色抵消
减去全在v的答案,然后点分治下去即可
统计答案可以用双指针统计,但是既然有sort的log,二分只是常数大一点…
点分治在统计(rt,u)的一条链时,实际是放入了一个rt=0的rt,所以统计上了这样的点对
#include
#include
#include
#include
#include
using namespace std;
#define pb push_back
typedef long long ll;
const int N=1e5+10;
int head[N],cnt;
struct edge{int v,nex;ll w;}e[2*N];
void add(int u,int v,ll w){e[++cnt]=edge{v,head[u],w};head[u]=cnt;}
bool vis[N];
int n,r,u,v;
ll k,res,w;
int siz,f[N],sz[N],rt;
struct node{
ll col,cnt,len,w,las,num;
}d[N],q[N],real;
vectornow,tmp,my[4];
bool cmp1(const node &a,const node &b){
return a.w-1ll*k*a.len<1ll*b.w-k*b.len;
}
bool cmp2(const node &a,const node &b){
if(a.col!=b.col)return a.col=1ll*k*(q[i].len+q[mid].len))y=mid-1;
else x=mid+1;
}
ans+=(r-x+1);//[x,r]
if(q[i].w+q[i].w>=1ll*k*(q[i].len+q[i].len))ans--;
}
sort(q+1,q+r+1,cmp2);
for(int i=1;i<=r;){
int j=i;
now.clear();tmp.clear();
for(int z=1;z<=3;++z){
my[z].clear();
}
tmp.pb(q[i]);
if(q[i].cnt==1)now.pb(q[i]);
if(q[i].cnt){
real=q[i];
if(q[i].cnt<=2)real.w-=q[i].col*q[i].cnt;
my[min(q[i].cnt,3ll)].pb(real);
}
while(j+1<=r && q[j+1].col==q[i].col){
j++;
tmp.pb(q[j]);
if(q[j].cnt==1)now.pb(q[j]);
if(q[j].cnt){
real=q[j];
if(q[j].cnt<=2)real.w-=q[j].col*q[j].cnt;
my[min(q[j].cnt,3ll)].pb(real);
}
}
sort(now.begin(),now.end(),cmp1);
sort(tmp.begin(),tmp.end(),cmp1);
for(int z=1;z<=3;++z){
sort(my[z].begin(),my[z].end(),cmp1);
}
int up=tmp.size();up--;
for(int z=0;z<=up;++z){
int x=0,y=up;
while(x<=y){
int mid=(x+y)/2;
if(tmp[z].w+tmp[mid].w>=1ll*k*(tmp[z].len+tmp[mid].len))y=mid-1;
else x=mid+1;
}
ans-=(up-x+1);
if(tmp[z].w+tmp[z].w>=1ll*k*(tmp[z].len+tmp[z].len))ans++;
}
up=now.size();up--;
for(int z=0;z<=up;++z){
int x=0,y=up;
while(x<=y){
int mid=(x+y)/2;
if(now[z].w+now[mid].w>=1ll*k*(now[z].len+now[mid].len))y=mid-1;
else x=mid+1;
}
ans+=(up-x+1);
if(now[z].w+now[z].w>=1ll*k*(now[z].len+now[z].len))ans--;
}
for(int s=1;s<=3;++s){
int sz=my[s].size();
for(int z=1;z<=3;++z){
if(s+z<3)continue;
int up=my[z].size();up--;
for(int h=0;h=1ll*k*(my[s][h].len+my[z][mid].len))y=mid-1;
else x=mid+1;
}
ans+=(up-x+1);
if(s==z && my[s][h].w+my[z][h].w>=1ll*k*(my[s][h].len+my[z][h].len))ans--;
}
}
}
i=j+1;
}
return ans/2;
}
void dfs(int u){
//每次用在u的子树里任取减去在v的子树里的答案
//每次只计算 必经过u的答案
res+=cal(u,0,0,0,0,0,0);
vis[u]=1;
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
ll w=e[i].w;
if(vis[v])continue;
res-=cal(v,w,1,1,w,w,1);
getrt(v,u,0);//获得正确的sz[v]
siz=sz[v];rt=0;
getrt(v,u,1);
dfs(rt);
}
}
int main(){
int T;
scanf("%d",&T);
while(T--){
scanf("%d%lld",&n,&k);
init(n);
for(int i=1;i