考虑相同颜色的两种节点,这两个节点会把树分成三部分(左、中、右),左部分的点不能和右部分的点组成一种方案
枚举每一个点,只要求出有多少个点能和它组成合法点对就行了
枚举每一对颜色相同的节点,在dfs序上搞一搞就行了
#include
#include
#include
#include
using namespace std;
typedef pair<int,int> par;
typedef long long ll;
const int N=100010;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline void read(int &x){
char c=nc(); x=0;
for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc());
}
int n,cnt,a[N],G[N],l[N],r[N];
vector<int> c[N];
vector de[N],ad[N];
struct edge{
int t,nx;
}E[N<<1];
inline void addedge(int x,int y){
E[++cnt].t=y; E[cnt].nx=G[x]; G[x]=cnt;
E[++cnt].t=x; E[cnt].nx=G[y]; G[y]=cnt;
}
int t,fa[N][20],dpt[N];
void dfs(int x,int f){
l[x]=++t; *fa[x]=f; dpt[x]=dpt[f]+1;
for(int i=1;i<=16;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=G[x];i;i=E[i].nx)
if(E[i].t!=f) dfs(E[i].t,x);
r[x]=t;
}
int mn[N<<2],tot[N<<2],tag[N<<2];
void Build(int g,int l,int r){
tot[g]=r-l+1;
if(l==r) return ;
int mid=l+r>>1;
Build(g<<1,l,mid); Build(g<<1|1,mid+1,r);
}
inline int kfa(int x,int y){
for(int i=16;~i;i--)
if(y>>i&1) x=fa[x][i];
return x;
}
inline void add(int l1,int r1,int l2,int r2){
ad[l1].push_back(par(l2,r2));
de[r1+1].push_back(par(l2,r2));
ad[l2].push_back(par(l1,r1));
de[r2+1].push_back(par(l1,r1));
}
inline void add(int x,int y){
if(l[x]>l[y]) swap(x,y);
if(r[x]>=r[y]){
int u=kfa(y,dpt[y]-dpt[x]-1);
add(1,l[u]-1,l[y],r[y]);
if(r[u]1,n,l[y],r[y]);
}
else add(l[x],r[x],l[y],r[y]);
}
inline void Push(int g){
if(tag[g]){
tag[g<<1]+=tag[g];
tag[g<<1|1]+=tag[g];
mn[g<<1]+=tag[g];
mn[g<<1|1]+=tag[g];
tag[g]=0;
}
}
void Modify(int g,int l,int r,int L,int R,int x){
if(l==L && r==R){
mn[g]+=x; tag[g]+=x; return ;
}
int mid=L+R>>1; Push(g);
if(r<=mid) Modify(g<<1,l,r,L,mid,x);
else if(l>mid) Modify(g<<1|1,l,r,mid+1,R,x);
else Modify(g<<1,l,mid,L,mid,x),Modify(g<<1|1,mid+1,r,mid+1,R,x);
mn[g]=min(mn[g<<1],mn[g<<1|1]);
tot[g]=(mn[g<<1]==mn[g]?tot[g<<1]:0)+(mn[g<<1|1]==mn[g]?tot[g<<1|1]:0);
}
int main(){
read(n);
for(int i=1;i<=n;i++)
read(a[i]),c[a[i]].push_back(i);
for(int i=1,x,y;i1,0); Build(1,1,n);
for(int i=1;i<=n;i++){
for(int j=0;jfor(int k=j+1;k0;
for(int i=1;i<=n;i++){
for(auto u : ad[i])
Modify(1,u.first,u.second,1,n,1);
for(auto u : de[i])
Modify(1,u.first,u.second,1,n,-1);
if(!mn[1]) ans+=tot[1];
}
ans=ans+n>>1;
printf("%lld\n",ans);
return 0;
}