带权并查集
f [ i ] f[i] f[i]表示节点i所在集合, d [ i ] d[i] d[i]表示节点i到所在集合代表节点的距离。
对于一组描述,查询L,R是否在同一个集合中
1.若在同一集合中,判断是否矛盾 若: d [ r ] − d [ l ] ! = z d[r]-d[l]!=z d[r]−d[l]!=z(z代表当前描述的D)则矛盾
2.若不在同一集合中,将L,R加入到同一集合中,不妨让R所在集合并到L所在集合之下,并更新到代表元素的距离d;
f [ f r ] = f l f[fr]=fl f[fr]=fl;
d [ f r ] = d [ l ] + z − d [ r ] d[fr]=d[l]+z-d[r] d[fr]=d[l]+z−d[r];
#include<bits/stdc++.h>
using namespace std;
const int N=200005,inf=0x3f3f3f3f;
int n,m,ans;
int f[N],d[N],ma[N],mi[N];
int read(){
int sum=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return sum*f;
}
void print(int x){
if(x<0)x=-x,putchar('-');
if(x>9)print(x/10);
putchar(x%10+'0');
}
int find(int x){
if(x==f[x])return x;
int root=find(f[x]);
d[x]+=d[f[x]];
return f[x]=root;
}
int main(){
n=read();m=read();
for(int i=1;i<=n;i++)
f[i]=i,mi[i]=inf,ma[i]=-inf;
int l,r,z,fl,fr;
for(int i=1;i<=m;i++)
{
l=read();r=read();z=read();
fl=find(l);fr=find(r);
if(fl==fr){
if(d[r]-d[l]!=z){
printf("impossible");
return 0;
}
}
else{
f[fr]=fl;
d[fr]=d[l]+z-d[r];
}
}
int fa;
for(int i=1;i<=n;i++)
{
fa=find(i);
ma[fa]=max(d[i],ma[fa]);
mi[fa]=min(d[i],mi[fa]);
}
for(int i=1;i<=n;i++)
if(mi[i]!=inf)ans=max(ans,ma[i]-mi[i]);
print(ans);
return 0;
}
spfa最短路+乘法原理
先spfa跑所有点到起点的最短路d[]
f[i]表示从起点到当前点i走最短路的方案数。g[i]表示从当前点i到终点走最短路的方案数。
从d[]由小到大排序,求出所有点的f[]值;从d[]由大到小排序,求出所有点的g[]值。
不考虑相遇情况,走最短路的方案有f[t]*g[s]种(s,t分别代表起点,终点)
相遇有两种情况:
1.在某点上相遇。很容易得到这个点i满足d[i]=d[t]/2 (即到起点的距离为全程最短路的一半)
ans=ans-(f[i] * g[i])^2
2.在某边上相遇。意味着这条边上存在某个点也满足到起点的距离为全程最短路的一半。转化成边上两端点x,y满足:d[x]
ans=ans-(f[x] * g[y])^2
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=200005,M=400005,inf=0x3f3f3f3f,mod=1000000007;
int n,m,s,t,tot;
int head[N],id[N],v[N];
ll f[N],g[N],ans,d[N];
struct edge{
int to,v,w;
}e[M];
void add(int x,int y,int z){
e[++tot].v=y;
e[tot].to=head[x];
e[tot].w=z;
head[x]=tot;
}
ll read(){
ll sum=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return sum*f;
}
void print(ll x){
if(x<0)x=-x,putchar('-');
if(x>9)print(x/10);
putchar(x%10+'0');
}
void spfa(){
queue<int>q;
q.push(s);
for(int i=1;i<=n;i++)
d[i]=inf,id[i]=i;
d[s]=0;
while(q.size()){
int x=q.front();q.pop();
v[x]=0;
for(int i=head[x];i;i=e[i].to)
{
int y=e[i].v;
if(d[x]+e[i].w<d[y]){
d[y]=d[x]+e[i].w;
if(!v[y])v[y]=1,q.push(y);
}
}
}
}
bool comp(int a,int b)
{
return d[a]<d[b];
}
void solve(){
sort(id+1,id+n+1,comp);
f[s]=1;g[t]=1;
for(int i=1;i<=n;i++)
{
int x=id[i];
for(int i=head[x];i;i=e[i].to)
{
int y=e[i].v;
if(d[y]==d[x]+e[i].w){
f[y]=(f[y]+f[x])%mod;
}
}
}
for(int i=n;i>=1;i--)
{
int x=id[i];
for(int i=head[x];i;i=e[i].to)
{
int y=e[i].v;
if(d[y]+e[i].w==d[x])g[y]=(g[y]+g[x])%mod;
}
}
ans=1ll*g[s]*f[t]%mod;
for(int x=1;x<=n;x++)
{
for(int i=head[x];i;i=e[i].to)
{
int y=e[i].v;
if(d[y]==d[x]+e[i].w)
if(1ll*d[y]*2>d[t]&&1ll*d[x]*2<d[t])ans=(ans-1ll*f[x]*f[x]%mod*g[y]%mod*g[y]%mod)%mod;
}
}
for(int i=1;i<=n;i++)
{
if(1ll*d[i]*2==d[t])ans=(ans-1ll*f[i]*f[i]%mod*g[i]%mod*g[i]%mod)%mod;
}
ans=(ans+mod)%mod;
}
int main(){
n=read();m=read();
s=read();t=read();
int u,v,z;
for(int i=1;i<=m;i++)
{
u=read();v=read();z=read();
add(u,v,z);add(v,u,z);
}
spfa();
solve();
print(ans);
return 0;
}
DFS序(欧拉序) + 树状数组维护链
欧拉序:从根结点出发,按dfs的顺序在绕回原点所经过所有点的顺序。
先将原来的树用欧拉序排序。
两条链相交的充要条件:
一条链上的深度最浅的点(就是LCA)在另一条链上
于是用树状数组的维护欧拉序上某一条链上有多少个其他链的LCA。
将每一条链上的其他链的LCA数量加起来再减去会重复的每个LCA数量(设为n)的平方与
C(2,n)(组合数)即为答案。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 2e6 + 5;
int n,m,head[N<<1],head2[N<<1],cnt=0;
bool vis[N<<1];
int dep[N<<1],f[N],fa[N],tot[N],into[N],out[N],sum[N];
ll ans=0;
struct qaq
{
int u,v,up;
}node[N<<1];
struct edge
{
int v,to,id;
}a[N<<1],e[N<<1];
void add(int u,int v)
{
e[++cnt].v=head[u]; e[cnt].to=v; head[u]=cnt;
e[++cnt].v=head[v]; e[cnt].to=u; head[v]=cnt;
}
void add2(int u,int v,int id)
{
a[++cnt].v=head2[u]; a[cnt].to=v; a[cnt].id=id;head2[u]=cnt;
a[++cnt].v=head2[v]; a[cnt].to=u; a[cnt].id=id;head2[v]=cnt;
}
int read(){
int sum=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return sum*f;
}
void print(int x){
if(x<0)x=-x,putchar('-');
if(x>9)print(x/10);
putchar(x%10+'0');
}
int lowbit(int x) {return x &(-x);}
void update(int x,int val)
{
while(x<=n*2)
{
sum[x]+=val;
x+=lowbit(x);
}
}
int ask(int x)
{
int ans=0;
while(x)
{
ans+=sum[x];
x-=lowbit(x);
}
return ans;
}
int find(int x) {return x == f[x] ? x : f[x] = find(f[x]);}
void dfs(int x,int pre)
{
into[x]=++cnt;
dep[x]=dep[pre] + 1;
fa[x]=pre;
vis[x]=1;
for(int i=head2[x];i;i=a[i].v)
{
int y=a[i].to;
if(vis[y])node[a[i].id].up=find(y);
}
for(int i=head[x];i;i=e[i].v)
{
int y=e[i].to;
if(y==pre)continue;
dfs(y,x);
}
out[x]=++cnt;
f[x]=pre;
}
ll C(int n)
{
if(n < 2) return 0;
return (1ll*n*(n-1)) >> 1;
}
int main()
{
n=read(); m=read();
int x,y;
for(int i = 1;i < n;i++)
{
x=read();y=read();
add(x,y);
}
cnt=0;
for(int i = 1;i <= m;i++)
{
x=read();y=read();
node[i].u=x;node[i].v=y;
add2(x,y,i);
add2(y,x,i);
}
for(int i=1;i<=n;i++)f[i]=i;
cnt=0; dfs(1,0);
for(int i = 1;i <= m;i++)
{
update(into[node[i].up],1);
update(out[node[i].up],-1);
tot[node[i].up]++;
}
for(int i = 1;i <= m;i++)
{
ans=ans+1ll*ask(into[node[i].u])+1ll*ask(into[node[i].v]);
ans=ans-1ll*ask(into[node[i].up])-1ll*ask(into[fa[node[i].up]]);
}
for(int i = 1;i <= n;i++)
ans = ans - 1ll * tot[i] * tot[i] + C(tot[i]);
print(ans);
return 0;
}