poi 1741
题目:http://poj.org/problem?id=1741
题意:给你一棵最多 10^4 个点组成的树,每根树枝的长度最多为 10^3 ,问你两个点之间的距离<=k 的点对数。
思路:
楼教男人八题之一。。 显然,O(N^2) 找点对的方法是不行的,而 O(NK) (k<=10^9)的动态规划也是不行的。好吧,具体思路参见漆子超的论文:http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###,很好的一篇论文,在树上用分治来做。在减少递归的层次上,用了一个技巧,就是每次都找这棵树的重心,其在最差情况下的层数为 O(logN),如果不这样做,那么最差情况,即一根链的时候,层数高达O(N),会TLE。算出每个点到根节点的距离后,匹配找出 Depth[i]+Depth[j]<=K 时,先将 dis 数组排序,然后从两边开始找,这样找的时间复杂度为O(N),而排序是 O(logN),所以总的是O(NlogN)。综上,该算法总的时间复杂度为 Nlog(N)*log(N)。
看了论文以后,第一遍做,过了样例,一交,还是TLE了,一看别人AC的,因为是双向边,我用了一个 vis 数组,这样每次都要清零,耗费了很多时间,其实只要记录它的father ,不要走回去就好。改过来后,交了,WA了。这次检查对照了好久,直到找到下面 6 6 这组数据时才发现,原来是在找 Depth[i]+Depth[j]<=K且Belong[i]=Belong[j]数对(i,j)的个数时候错了。因为我每次都是dfs2(u,0,0),再算del的时候,应该是的dfs2(u,0,len),连接u、v的那条边要算进去,作为dis加上去的初始值。只能说,最后AC的时候好开心,又向男人进了一步。。 = =
代码如下:
#include<cstdio> #include<cstring> #include<vector> #include<algorithm> using namespace std; const int INF = 0x0fffffff ; const int MAXN = 11111 ; int n,m; struct Edge { int t,next,len; } edge[MAXN<<1]; int head[MAXN],tot; void add_edge(int s,int t,int len) { edge[tot].t=t; edge[tot].len=len; edge[tot].next = head[s]; head[s] = tot++; } int root; int getted[MAXN]; vector <int> node; int num[MAXN],maxv[MAXN]; void dfs1(int u,int fa) { node.push_back(u); num[u]=1; maxv[u]=0; for(int e = head[u] ;e!=-1;e=edge[e].next) { int v = edge[e].t; if(getted[v]||v==fa) continue; dfs1(v,u); num[u]+=num[v]; maxv[u]=max(maxv[u],num[v]); } } void get_root(int x) { node.clear(); dfs1(x,0); int minn=INF; int sum_node = num[x]; for(int i=0;i<node.size();i++) { int cur = node[i]; maxv[cur] = max(maxv[cur],sum_node-num[cur]); if(maxv[cur]<minn) { minn = maxv[cur]; root = cur; } } } vector <int> dis; void dfs2(int u,int fa,int s) { dis.push_back(s); for(int e = head[u];e!=-1;e =edge[e].next) { int v = edge[e].t; int len = edge[e].len; if(getted[v]||v==fa||s+len>m) continue; dfs2(v,u,s+len); } } void get_dis(int u,int dist) { dis.clear(); dfs2(u,0,dist); } int ans; void count_add() { get_dis(root,0); sort(dis.begin(),dis.end()); int j=dis.size()-1; for(int i = 0;i<j;) { if(dis[i]+dis[j]<=m) { ans+=j-i; i++; } else { j--; } } /*printf("dis+\n"); for(int i=0;i<dis.size();i++) printf("%d ",dis[i]); puts("");*/ } void count_del() { for(int e = head[root] ; e!=-1;e=edge[e].next) { int v = edge[e].t; int len = edge[e].len; if(getted[v]) continue; get_dis(v,len); sort(dis.begin(),dis.end()); /* puts("dis-"); for(int i=0;i<dis.size();i++) printf("%d ",dis[i]); puts("");*/ int j = dis.size()-1; for(int i=0;i<j;) { if(dis[i]+dis[j]<=m) { ans-=(j-i); i++; } else j--; } } } void solve(int x) { get_root(x); getted[root] = 1; //printf("root = %d\n",root); count_add(); //printf("ans1 = %d\n",ans); count_del(); //printf("ans2 = %d\n",ans); for(int e = head[root] ; e != -1; e =edge[e].next) { int v = edge[e].t; if(getted[v]) continue; solve(v); } } int main() { while(~scanf("%d%d",&n,&m)) { if(n+m==0) break; memset(head,-1,sizeof(head)); tot=0; int a,b,c; for(int i=1;i<n;i++) { scanf("%d%d%d",&a,&b,&c); add_edge(a,b,c); add_edge(b,a,c); } memset(getted,0,sizeof(getted)); ans=0; solve(1); printf("%d\n",ans); } return 0; } /* 5 4 1 2 3 1 3 1 1 4 2 3 5 1 6 6 1 2 3 1 3 1 1 4 2 3 5 1 5 6 1 0 0 */