题目大意:
给出一棵树, 顶点数 n <= 100000
每个点有一个权值
给出K
询问这棵树中两点路径上的点的权值乘积对1e6 + 3取模之后等于K的路径的两个端点形成的点对中字典序最小的
大致思路:
点分治第二题
用f[i]表示从当前的子树根开始到某个点的路径乘积为i的点的最小标号
于是对于每一次分治, 去掉重心x之后, 依次处理每一棵子树来使得f的来源不一样
注意要预处理逆元
代码如下:
Result : Accepted Memory : 24420 KB Time : 2823 ms
/* * Author: Gatevin * Created Time: 2015/10/14 9:59:26 * File Name: Sakura_Chiyo.cpp */ #include<iostream> #include<sstream> #include<fstream> #include<vector> #include<list> #include<deque> #include<queue> #include<stack> #include<map> #include<set> #include<bitset> #include<algorithm> #include<cstdio> #include<cstdlib> #include<cstring> #include<cctype> #include<cmath> #include<ctime> #include<iomanip> using namespace std; const double eps(1e-8); typedef long long lint; #define maxn 100010 struct Edge { int u, v, nex; Edge(){} Edge(int _u, int _v, int _nex) { u = _u, v = _v, nex = _nex; } }; Edge edge[maxn << 1]; int head[maxn]; int E; int n; void add_Edge(int u, int v) { edge[++E] = Edge(u, v, head[u]); head[u] = E; } int del[maxn]; int root; int mx[maxn]; int size[maxn]; int mi; int N; lint K; lint w[maxn]; const lint mod = 1e6 + 3; pair<int, int> ans; lint rev[1000010]; void getRev() { rev[1] = 1; for(lint i = 2; i < mod; i++) rev[i] = (mod - mod / i) * rev[mod % i] % mod; } void dfs_size(int now, int father) { size[now] = 1; mx[now] = 1; for(int i = head[now]; i + 1; i = edge[i].nex) { int v = edge[i].v; if(v != father && !del[v]) { dfs_size(v, now); size[now] += size[v]; if(size[v] > mx[now]) mx[now] = size[v]; } } } void dfs_root(int r, int now, int father) { if(size[r] - size[now] > mx[now]) mx[now] = size[r] - size[now]; if(mx[now] < mi) mi = mx[now], root = now; for(int i = head[now]; i + 1; i = edge[i].nex) { int v = edge[i].v; if(v != father && !del[v]) dfs_root(r, v, now); } } lint f[1000010];//f[i]表示距离为i的字典序最小的点的标号+pre const pair<int, int> inf = make_pair(1e9, 1e9); lint pre; const lint bit = 1e5; void get(int now, int father, lint dis, int flag) { dis = dis*w[now] % mod; if(flag == 1) { //dis*x % mod = K -> x = K*dis^(mod - 2) % mod if(f[K*rev[dis] % mod] > pre) { int other = (int)(f[K*rev[dis] % mod] - pre); pair<int, int> p = other < now ? make_pair(other, now) : make_pair(now, other); if(ans > p) ans = p; } } else { if(f[dis] <= pre) f[dis] = pre + now; else f[dis] = min(f[dis], pre + now); } for(int i = head[now]; i + 1; i = edge[i].nex) { int v = edge[i].v; if(!del[v] && v != father) get(v, now, dis, flag); } } void dfs(int now) { mi = N; dfs_size(now, 0); dfs_root(now, now, 0); del[root] = 1; pre += bit; f[w[root]] = pre + root;//pre是因为每次清空f数组是不现实的, 于是用pre/bit表示第几次 for(int i = head[root]; i + 1; i = edge[i].nex) { int v = edge[i].v; if(!del[v]) { get(v, root, 1, 1); get(v, root, w[root], 0); } } for(int i = head[root]; i + 1; i = edge[i].nex) { int v = edge[i].v; if(!del[v]) dfs(v); } } void solve() { ans = inf; //pre = 0 //memset(f, 0, sizeof(f));利用pre不清0, f就可以不清0了 dfs(1); if(ans == inf) puts("No solution"); else printf("%d %d\n", ans.first, ans.second); } int main() { getRev();//预处理逆元 pre = 0; while(scanf("%d %I64d", &N, &K) != EOF) { E = 0; memset(head, -1, sizeof(head)); memset(del, 0, sizeof(del)); for(int i = 1; i <= N; i++) scanf("%I64d", &w[i]); int u, v; for(int i = 1; i < N; i++) { scanf("%d %d", &u, &v); add_Edge(u, v); add_Edge(v, u); } solve(); } return 0; }