学习博客:https://www.cnblogs.com/bztMinamoto/p/9489473.html
点分治板子
详见下方代码
#include
#define M 80009
using namespace std;
int read(){
int f=1,re=0;char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
const int inf=1e9+7;
int tot,n,k,first[M],nxt[M],to[M],w[M],siz[M],vis[M],s[M],top,dis[M],size,num,rt,ans;
void add(int x,int y,int z){nxt[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;}
void getroot(int u,int fa){//找重心
siz[u]=1;int maxn=0;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]||v==fa) continue;
getroot(v,u);
siz[u]+=siz[v];
maxn=max(maxn,siz[v]);
}maxn=max(size-siz[u],maxn);//注意此处是size-siz[u],不是n-siz[u]
if(maxn<num) rt=u,num=maxn;
}
void getdis(int u,int fa){
s[++top]=dis[u];
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]||v==fa) continue;
dis[v]=dis[u]+w[i];
getdis(v,u);
}
}
int solve(int u,int val){
dis[u]=val;top=0;
getdis(u,0);
sort(s+1,s+top+1);
int l=1,r=top,sum=0;
while(l<r){
if(s[l]+s[r]<=k) sum+=r-l,l++;
else r--;
}return sum;
}
void dfs(int u){
ans+=solve(u,0);vis[u]=1;//统计以该节点为根的答案
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
ans-=solve(v,w[i]);//删去它儿子的影响;因为第一次统计答案时,有来自同一棵子树的非法路径
size=siz[v]>siz[u]?n-siz[u]:siz[v];//size表示当前处理的子树大小
num=inf;
getroot(v,0);//注意是0
dfs(rt);
}
}
int main(){
n=read();
for(int i=1;i<n;i++){
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}k=read();
size=n,num=inf;
getroot(1,0);dfs(rt);
printf("%d\n",ans);
return 0;
}
点分治板子
直接将路径分为3类,余0,余1,余2,然后开个桶记录即可,最后计算答案
#include
#define int long long
#define M 40009
using namespace std;
int read(){
int f=1,re=0;char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
const int inf=1e9+7;
int tot,n,first[M],nxt[M],to[M],w[M],siz[M],vis[M],dis[M],size,num,rt,ans,cnt[3];
void add(int x,int y,int z){nxt[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;}
int gcd(int a,int b){
if(b==0) return a;
else return gcd(b,a%b);
}
void getroot(int u,int fa){//求重心
siz[u]=1;int maxn=0;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]||v==fa) continue;
getroot(v,u);
siz[u]+=siz[v];
maxn=max(maxn,siz[v]);
}maxn=max(size-siz[u],maxn);
if(maxn<num) rt=u,num=maxn;
}
void getdis(int u,int fa){
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]||v==fa) continue;
dis[v]=dis[u]+w[i];
getdis(v,u);
}cnt[dis[u]%3]++;
}
int solve(int u,int val){
dis[u]=val;
memset(cnt,0,sizeof(cnt));
getdis(u,0);
return cnt[0]*(cnt[1]+cnt[2])*2+(cnt[1]*cnt[1])+(cnt[2]*cnt[2]);
}
void dfs(int u){
ans+=solve(u,0);vis[u]=1;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
ans-=solve(v,w[i]);
size=siz[v]>siz[u]?n-siz[u]:siz[v];//注意对节点数的处理
num=inf;
getroot(v,0);//注意是0
dfs(rt);
}
}
signed main(){
n=read();
for(int i=1;i<n;i++){
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}size=n,num=inf;
getroot(1,0);dfs(rt);
int d=gcd(n*n,n*n-ans);
printf("%lld/%lld\n",(n*n-ans)/d,n*n/d);
return 0;
}
点分治板子
统计的时候, O ( n 2 ) O(n^2) O(n2)枚举统计,但会TLE,更好的做法目前还不会
#include
#define M 10009
using namespace std;
int read(){
int f=1,re=0;char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
const int inf=1e9+7;
int tot,n,m,first[M],nxt[M<<1],to[M<<1],w[M<<1],siz[M],vis[M],s[M],top,dis[M],size,num,rt,ans,cnt[10000009];
void add(int x,int y,int z){nxt[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;}
void getroot(int u,int fa){
siz[u]=1;int maxn=0;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
getroot(v,u);
siz[u]+=siz[v];
maxn=max(siz[v],maxn);
}maxn=max(maxn,size-siz[u]);//注意是size-siz[u];
if(maxn<num) rt=u,num=maxn;
}
void getdis(int u,int fa){
s[++top]=dis[u];
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+w[i];getdis(v,u);
}
}
void solve(int u,int val){
dis[u]=val;top=0;
getdis(u,0);
for(int i=1;i<=top;i++)
for(int j=i+1;j<=top;j++){
if(s[i]+s[j]>10000000) continue;
if(val) cnt[s[i]+s[j]]--;
else cnt[s[i]+s[j]]++;
}
}
void dfs(int u){
solve(u,0),vis[u]=1;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
solve(v,w[i]);
size=siz[v]>siz[u]?n-siz[u]:siz[v];
num=inf,rt=0,getroot(v,0);
dfs(rt);
}
}
int main(){
n=read(),m=read();
for(int i=1;i<n;i++){
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}size=n,num=inf;
getroot(1,0);dfs(rt);
for(int i=1;i<=m;i++){
int k=read();
if(cnt[k]) printf("AYE\n");
else printf("NAY\n");
}return 0;
}