我这个脑回路比较奇特的做法是：
树链剖分 + 树状数组
其中树状数组要支持区间修改、区间求和，并且每次查询后都要回溯（撤销本次查询所做的区间修改）。
当我看到别人直接求 LCA 再分类讨论时，感觉自己的智商受到了碾压。。。
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;

// Capacity of every per-vertex / per-position array (indices are 1-based).
const int N = 100005;
const int INF = 0x3f3f3f3f;

// Shorthands: fill a whole array bytewise, and build a std::pair.
#define MS(arr, byte) memset((arr), (byte), sizeof(arr))
#define MP(a, b) make_pair((a), (b))

int n, q;  // vertex count and query count of the current test case
// Forward-star adjacency list. edge[k] is the k-th directed edge: `to` is its
// target and `next` is the index of the previously added edge from the same
// source (or -1). head[v] is the index of the last edge added from v, or -1
// after main() resets it. Capacity is 2N because main() stores every tree
// edge in both directions.
struct Node{
    int to, next;
} edge[N << 1];
int tot;      // number of directed edges stored so far
int head[N];

// Appends the directed edge x -> y and makes it the new head of x's list.
void addedge(int x, int y) {
    Node &slot = edge[tot];
    slot.to = y;
    slot.next = head[x];
    head[x] = tot;
    tot += 1;
}
// Heavy-light decomposition state, per vertex v:
//   top[v]  head of v's heavy chain      fa[v]   parent (root is its own parent)
//   son[v]  heavy child, -1 if none      deep[v] depth (root = 1 in main)
//   num[v]  subtree size                 p[v]    position in chain order
// fp is the inverse of p (fp[p[v]] == v); pos is the next position to hand out.
int top[N], fa[N], son[N], deep[N], num[N], p[N], fp[N], pos;

// First HLD pass: fills fa, deep and num for the subtree rooted at x, and picks
// son[x] as the child with the largest subtree (the first one seen wins ties).
// son[] must be preset to -1. Recursion depth equals the height of the tree.
void dfs1(int x, int pre, int dep) {
    fa[x] = pre;
    deep[x] = dep;
    num[x] = 1;
    for (int e = head[x]; e != -1; e = edge[e].next) {
        const int child = edge[e].to;
        if (child != pre) {
            dfs1(child, x, dep + 1);
            num[x] += num[child];
            const int heavy = son[x];
            if (heavy == -1 || num[child] > num[heavy]) son[x] = child;
        }
    }
}
// Second HLD pass: numbers vertices so that each heavy chain occupies a
// contiguous run of positions. tp is the head of the chain x belongs to.
// The heavy child is visited first so that it gets the position right after x.
void dfs2(int x, int tp) {
    top[x] = tp;
    p[x] = pos;
    fp[pos] = x;
    ++pos;
    if (son[x] == -1) return;  // leaf: no children to number
    dfs2(son[x], tp);          // heavy child continues the current chain
    for (int e = head[x]; e != -1; e = edge[e].next) {
        const int y = edge[e].to;
        if (y == son[x] || y == fa[x]) continue;
        dfs2(y, y);            // each light child starts a new chain
    }
}
// Fenwick-tree pair for range add / range sum over positions 1..n: Find()
// feeds tree1 the difference array d[i] and tree2 the products i * d[i].
ll tree1[N];
ll tree2[N];

// Fenwick point update: adds val at index pos and at every node above it,
// up to n. pos must be >= 1 (pos & -pos is 0 for pos == 0, so it would not
// advance); pos > n is a no-op.
void Add(ll tree[], int pos, int val) {
    while (pos <= n) {
        tree[pos] += val;
        pos += pos & -pos;  // move to the next node whose range covers pos
    }
}
// Fenwick prefix query: returns the stored sum over indices 1..pos
// (0 when pos == 0).
ll Sum(ll tree[], int pos) {
    ll total = 0;
    // pos &= pos - 1 clears the lowest set bit, i.e. pos -= pos & -pos.
    for (; pos != 0; pos &= pos - 1)
        total += tree[pos];
    return total;
}
// Undo log for the Fenwick updates made by Find(). Each entry is (index, delta).
// A positive index means "add delta to tree1 at index". A negated index means
// "add delta to tree2 at -index". HLD positions start at 1, so the sign is
// never ambiguous. solve() replays these entries to roll the marks back.
std::vector<std::pair<int, int> > Resume;

// Adds 1 to every HLD position in [l, r] (range add on the tree1/tree2 pair)
// and logs the inverse updates in Resume.
static void MarkRange(int l, int r) {
    Add(tree1, l, 1);
    Add(tree1, r + 1, -1);
    Add(tree2, l, l);
    Add(tree2, r + 1, -(r + 1));
    Resume.push_back(MP(l, -1));
    Resume.push_back(MP(r + 1, 1));
    Resume.push_back(MP(-l, -l));
    Resume.push_back(MP(-(r + 1), r + 1));
}

// Marks every vertex on the tree path x-y by climbing heavy-light chains until
// both endpoints lie on the same chain. Each chain segment is one contiguous
// position range.
void Find(int x, int y) {
    int fx = top[x];
    int fy = top[y];
    while (fx != fy) {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[fx] < deep[fy]) {
            std::swap(fx, fy);
            std::swap(x, y);
        }
        MarkRange(p[fx], p[x]);
        x = fa[fx];
        fx = top[x];
    }
    // Same chain: the shallower vertex has the smaller position.
    if (deep[x] > deep[y]) std::swap(x, y);
    MarkRange(p[x], p[y]);
}
// Sum over HLD positions [l, r] of the values range-added into tree1/tree2,
// using prefix(k) = (k + 1) * S1(k) - S2(k) on the tree1/tree2 pair.
static ll RangeCount(int l, int r) {
    const ll throughR = 1ll * (r + 1) * Sum(tree1, r) - Sum(tree2, r);
    const ll beforeL = 1ll * l * Sum(tree1, l - 1) - Sum(tree2, l - 1);
    return throughR - beforeL;
}

// Returns the sum of the range-added values over all vertices on the tree path
// x-y. It walks the same heavy-light chains that Find() marks.
ll Total(int x, int y) {
    ll marked = 0;
    int tx = top[x];
    int ty = top[y];
    while (tx != ty) {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[tx] < deep[ty]) {
            std::swap(tx, ty);
            std::swap(x, y);
        }
        marked += RangeCount(p[tx], p[x]);
        x = fa[tx];
        tx = top[x];
    }
    if (deep[x] > deep[y]) std::swap(x, y);
    marked += RangeCount(p[x], p[y]);
    return marked;
}
// Marks every vertex on the a-b path, counts how many vertices of the c-d path
// are marked, then undoes exactly the updates Find() made. Returns the count.
ll solve(int a, int b, int c, int d) {
    Resume.clear();
    Find(a, b);
    const ll overlap = Total(c, d);
    // Replay the inverse updates logged by Find(): a positive index targets
    // tree1, a negated index targets tree2.
    for (std::size_t i = 0; i < Resume.size(); ++i) {
        const int at = Resume[i].first;
        const int delta = Resume[i].second;
        if (at > 0) Add(tree1, at, delta);
        else Add(tree2, -at, delta);
    }
    return overlap;
}
int main() {
while(~scanf("%d %d", &n, &q)) {
MS(tree1, 0); MS(tree2, 0);
memset(head, -1, sizeof(head));
memset(son, -1, sizeof(son));
tot = 0; pos = 1;
for(int i = 2; i <= n; ++i) {
int a; scanf("%d", &a);
addedge(a, i); addedge(i, a);
}
dfs1(1, 1, 1);
dfs2(1, 1);
// for(int i = 1; i <= n; ++i) printf("%d ", p[i]); printf("\n");
for(int i = 0; i < q; ++i) {
int a, b, c; scanf("%d %d %d", &a, &b, &c);
ll ans = max(max(solve(a,b, a,c), solve(a,b, b,c)), solve(a,c, b,c));
printf("%lld\n", ans);
}
}
return 0;
}