题目:http://acm.hdu.edu.cn/showproblem.php?pid=4670
题目大意:给你一棵有n个顶点的树,每个节点有一个权值,给你k个prime,每个权值都可以由这k个prime的幂次方的和组成,问你在树上有多少条路径,使这条路径上的点的权值积是一个立方数。
思路:一个数是立方数,当且只当它的拆分成的所有质因子都是幂都是3的倍数,而质因子的4、7次幂和1次幂一样,5次、8次和2次一样,则对于幂次的个数,我们都对3取余,结果不变。每次选取这棵树的重心,然后算出所有经过这个点的路径数(暴搜),最后全部加起来即可。一个儿子,一个儿子搜过来,如果当前这个儿子这条路径的各因数幂次都知道了,那么我们只要加上之前有的它的互补的幂次的个数就好。互补就是说,加起来都是3的倍数,也就是全是0,比如:当前路径为0 1 2,那么只要看看 0 2 1这条路径之前出现的个数,加起来就行了。还有,这里的互补不能用 2 2 2 的满状态去剪应该是 3 3 3 的,比如:1 1 1,那么一剪还是1 1 1,其实是2 2 2。这里还有一个地方特别需要注意:k最大为30,可以用lld存下来表示状态,但是用数组哈希显然是不行的,用map,清零和哈希都很方便。
其实挺简单的说,功力太差,有个地方一直搞来搞去,搜的那条路径不包括根节点,然后更新 hash 的时候要加上根节点,初始化就是只包含根节点的状态为1。调了一个下午,挫了。。= =
搓代码一份,如下:
#pragma comment(linker, "/STACK:10240000000000,10240000000000") #include<cstdio> #include<cstring> #include<map> #include<vector> #include<algorithm> using namespace std; typedef __int64 lld; const int MAXN = 55555 ; int n,k; struct Edge { int next,t; } edge[MAXN<<1]; int tot ,head[MAXN]; void add_edge(int s,int t) { edge[tot].t = t; edge[tot].next = head[s]; head[s] = tot++; } struct Node { int cnt[33]; } node[MAXN]; int num[MAXN],maxv[MAXN]; int vis[MAXN]; void get_size(int u,int fa) { num[u] = 1; maxv[u] = 0; for(int e = head[u];e!=-1;e = edge[e].next) { int v = edge[e].t; if(vis[v]||v==fa) continue; get_size(v,u); num[u] += num[v]; maxv[u] = max(maxv[u],num[v]); } } int minn ; void find_root(int u,int fa,int &root,int sum) { int tmp = max(sum - num[u],maxv[u]); if(tmp < minn ) { minn = tmp; root = u; } for(int e = head[u];e!=-1;e=edge[e].next) { int v = edge[e].t; if(vis[v]||fa==v) continue; find_root(v,u,root,sum); } } int get_root(int u) { get_size(u,-1); int sum = num[u]; minn = n; int root = u; find_root(u,-1,root,sum); return root; } lld exp[33]; void init() { exp[0] = 1; for(int i = 1;i<=30;i++) exp[i] = exp[i-1] * 3; } map <lld,int> sta; int ss[33]; int ret; vector <lld> vec; void dfs(int u,int fa,int root) { for(int i = 0;i<k;i++) ss[i] = (ss[i] + node[u].cnt[i])%3; lld cc = 0,cc2 = 0; for(int i = 0;i<k;i++) { cc += (3 - ss[i])%3*exp[i]; cc2 += (ss[i]+node[root].cnt[i])%3*exp[i]; } vec.push_back(cc2); ret += sta[cc]; for(int e = head[u];e!=-1;e = edge[e].next) { int v = edge[e].t; if(vis[v]||v==fa) continue; dfs(v,u,root); for(int i = 0;i<k;i++) ss[i] = (ss[i] - node[v].cnt[i] + 3)%3; } } int count(int u) { ret = 0 ; sta.clear(); lld cc = 0; for(int i = 0;i<k;i++) { cc += node[u].cnt[i]*exp[i]; } sta[cc] = 1; if(cc == 0) ret = 1; for(int e = head[u] ; e!=-1;e = edge[e].next) { int v = edge[e].t; if(vis[v]) continue; memset(ss,0,sizeof(ss)); vec.clear(); dfs(v,u,u); for(int i = 0;i<vec.size();i++) sta[vec[i]] ++ ; } //printf("ret = %d\n",ret); return ret; } int ans ; void solve(int u) { int root = get_root(u); //printf("root = %d\n",root); vis[root] = 1; ans += count(root); for(int e = head[root] ; e!=-1 ;e = edge[e].next) { int v = edge[e].t; if(vis[v]) continue; solve(v); } } lld pri[33]; int main() { init(); while(~scanf("%d",&n)) { scanf("%d",&k); for(int i = 0;i<k;i++) scanf("%I64d",&pri[i]); for(int i = 0;i<n;i++) { lld tmp; scanf("%I64d",&tmp); memset(node[i].cnt,0,sizeof(node[i].cnt)); for(int j = 0;j<k;j++) { while(tmp&&(tmp%pri[j]==0)) { node[i].cnt[j] ++; node[i].cnt[j] = node[i].cnt[j]%3; tmp = tmp/pri[j]; } if( tmp == 0 ) break; } } tot=0; memset(head,-1,sizeof(head)); int a,b; for(int i = 1;i<n;i++) { scanf("%d%d",&a,&b); a--; b--; add_edge(a,b); add_edge(b,a); } memset(vis,0,sizeof(vis)); ans = 0; solve(0); printf("%d\n",ans); } return 0; } /* 6 2 2 3 36 36 36 36 36 36 1 2 2 3 3 4 4 5 5 6 */